ggaaooppeenngg

为什么计算机科学是无限的但生命是有限的

Sequence Parallelism

假设有4个chunk,切四份。

初始化状态,每个GPU都有自己的 Qn Kn,可以计算出对应的注意力矩阵,然后类似AllReduce的方式传递切分的K。

第一步环形传递K,然后再算一次注意力矩阵。

第二步环形传递K,然后再算一次注意力矩阵。

第三步全部传完,得到完整的Sn。

然后 Sn 和 Vn 的计算也是类似的,经过三次环形传递Vn,然后每一份可以单独和小s的那一份做乘法。

所以K和V的传播都要经历 3 次(N-1)的集合通信。

LLM推理的核心在于KVCache的调度。

  1. 尽可能多次重用KV缓存,以减少所需的计算资源;
  2. 每批次最大化token数量,从而改善Model FLOPs Utilization (MFU)。

如果从远程内存获取KVCache,会增加数据传输时间,从而延长TTFT(Time To First Token)。因此,当本地KVCache的增量计算时间少于传输时间时,可以复用本地的KVCache,即使它不是最匹配的。而增大batch意味着系统处理的大批量数据,导致TBT(Token Between Token)延长,可以将负载均衡到低负载的Decode Instance。

架构

Mooncake的架构图主要分为三个部分:Prefill Instance,Decode Instance,Conductor。

  1. Cache-aware Prefill Scheduler:负责调度Request到Prefill Instance,主要考虑load和KVCache的复用率。
  2. KVCache Balance Scheduler:负责从匹配最多前缀的P2P传输KVCache到Instance(Decode和Prefill)。
  3. Load-balance Decoding Scheduler:负责负载均衡调度Request到Decode Instance。

Prefill Instance要满足TTFT SLO,最小化MFU,保证KVCache < DRAM。
Decode Instance要满足TBT SLO,保证KVCache < VRAM。
Inter-Node Transfer基于RDMA的P2P,这也是一个较大的开销。

Mooncake的方法总结如下:

  1. 转移可重用的KVCache,将尽可能多的可重用KVCache转移至Prefill Instance,减少增量计算的时间。
  2. Prefill Instance Pool分层并分块处理,并持续输出给对应的Decode Instance。分层指的是Layer-wise KVCache的异步保存,分块指的是Chunked Pipeline Parallelism。
  3. 独立的Decode Instance Pool加载KVCache,通过连续批处理解码tokens。

Mooncake的主要特点是将prefill和decode拆开,并调度KVCache块。

Reject Policy:如果一个请求不能在服务水平内完成其完整的执行,那么就应该尽早拒绝这个请求,基于这个理念需要设计一些拒绝策略,被称作Overloaded-Scheduling。

KVCache的复制

KVCache的调度主要是利用KV Cache(VRAM,DRAM),利用RDMA带宽。

下图是一个Prefill和Decode分离的计算过程。

如果了解vLLM中的prefill和decode以及管理block的方法,这个图其实很简单。

首先通过Hash判断block是否相同,例如很多系统提示词都是一样的,这部分的复用率很高。

Prefill Instance已经有了ABCDE(这里是一个P2P的过程,但我看开源的版本有个KVCache Store的WIP,不知道后面会不会有一个中心化的KVCache Store的组件)。然后计算了FGHI,存入了KV Cache(在CPU mem上),论文里面提到这个prefill在超过prefill_chunk tokens数量会做chunked prefill。

接着通过Messenger以RDMA的方式发给Decode Instance。Decode Instance基于ABCDEFGHI的prompt对应的KV Cache开始decode的过程。

根据请求模式,它可以使用缓存淘汰算法,如LRU(最近最少使用),LFU(最不常用的),或基于请求特征的算法。这些KVCache块在CPU和GPU之间的传输由一个独立的(GPUDirect)RDMA组件Messenger处理。这种架构还使我们能够为外部用户提供KVCache缓存API,从而实现更高的缓存重用性。

Mooncake已经开源了他的代码,目前只有Transfer Engine。

基于这个架构,Conductor的主要功能是:

  1. 根据当前的KVCache分布和工作负载,分发请求。
  2. 复制或交换某些KVCache块,以便于未来推理。如果某些块的数据在未来被频繁访问,Conductor可能会将其复制到其他节点上,从而提高推理效率。

Mooncake的一个争论点是,是否需要在存在chunked prefill的情况下采用这种分离架构。毕竟,chunked prefill可以填补许多pipeline中的气泡,并且能让prefill和decode节点相对统一,只需要关心一种instance,对于scheduler比较友好。

  1. 不分离的优点:

    • 所有节点被视为平等,使调度更简单;
    • 将chunked prefill内联到解码批处理中可以提高解码批次的计算强度,从而提高MFU。
  2. 分离的优点:

    • 长文本的跨节点并行和VRAM的节省。长文本输入是输出的10倍甚至100倍,对于相同的模型来说,prefill需要多节点配置才能满足显存需求。prefill阶段可以进行layer-wise prefill,每次保存大量KVCache,而decode阶段每次只需保存一个KVCache。因此,prefill阶段可以通过layer-wise prefill来减少VRAM占用。

是这么理解么?异步的Store KVCache可以节省保存的时间,但这是Prefill和Decode分离的理由么?Decode阶段应该是不保存KVCache?

然而,经过仔细考虑,论文决定保持Mooncake的分离架构。只有在请求的prefill可以不进行chunking且不影响TBT SLO的情况下,才会将其内联到解码批次中。我们这样决定的主要原因有两个:

  1. Prefill节点需要不同的跨节点并行设置来处理长上下文 (§5.1)。

  2. 这为节省VRAM提供了独特的机会 (§5.2)。

  3. 大模型需要部署在多机上,进行TP后,每一层都需要进行一次基于RDMA的reduce,这个过程开销巨大。虽然有一些Sequence Parallelism的方法,但效果并不理想,且无法避免跨节点通信。而Mooncake采用的是CPP(Chunked Parallelism Pipeline),将序列按prefill_chunk大小切分,交给prefill pool的不同节点,这些节点被切分成更小的节点池(pipelined prefill node group)。

疑问:他们是pipe的不同部分?还是完全对等的?目前感觉是PP是分layer做Pipe,而CPP是sequence分chunked做pipe。24引用的论文中提到的Sequence Pipeline可以再看一下,应该对理解这个有帮助。

  1. Layer-wise prefill,这有点像airllm项目,在计算过程中动态加载KVCache。在每次注意力计算时,KVCache是异步加载的,计算当前层时可以异步加载下一层,并且当前层结束后可以异步保存当前层。论文中认为KVCache的保存时间可以被完全省略(相较于加载计算保存的线性循环)。这样也可以降低VRAM的占用。

调度算法

  1. 选择Prefill实例

    • 如果Prefill节点上缓存了足够的前缀(由kvcache_balancing_threshold控制),则选择预估TTFT最小的实例:TTFT = min(T_queue + T_prefill)
    • 如果Prefill节点上缓存不足,则选择TTFT = min(T_queue + T_prefill + T_transfer)最小的实例,其中T_transfer指的是有最长匹配的KVCache的实例拷贝到当前实例的预估时间。
  2. 选择Decode实例

    • 通过负载均衡的方式预估TBT。
    • 如果TBT和TTFT不满足SLO,则拒绝请求,并触发KVCache的传输。
  3. 预测模型

    • 预估模型用于预测传输时间和决策传输。
    • 数据传输时间难以预测,因为它不仅取决于数据大小,还依赖于当前网络状态,特别是当发送节点处于拥塞状态时。
  4. KVCache复制

    • 热门的KVCache块需要被复制以确保高可用性。
  5. 调度器目标

    • 保证低Cache负载和高Cache命中率。
  6. 高负载情况下的策略

    • 请求可能不会被直接发送给缓存最长前缀的实例,而是转发给备选实例。备选实例会主动从缓存持有者处检索KV缓存并存储本地。
    • 当最佳的远程前缀匹配长度不超过当前本地可重用前缀的阈值时,系统优先使用本地缓存,而不是从远程实例获取令牌。

这些策略不仅减少了请求的Prefill时间,还自动复制热点缓存,使其在多台机器上更广泛地分布。

拒绝策略

论文提到了一种基于预测的拒绝策略。Prefill和Decode的负载节奏是相反的,可能在Decode负载高时,Prefill负载较低。此时如果拒绝请求,会导致Decode负载下降,而Prefill完成后Decode负载又会升高,进而再次拒绝请求。引入预测拒绝策略后,可以使Prefill过程更加平滑,减少频繁拒绝请求的情况,从而减小负载节奏的波动。

参考Pytorch版本的llama-from-scratch。原文中的RmsNorm的平均值多算了一个维度,这里改成了正确的版本。

首先需要下载TinyShakespeare数据集,这是一个莎士比亚文字数据集。本文档将大致遵循论文的布局,并跳过一些明显的步骤,比如设置虚拟环境和安装依赖项。

我们最终将实现的内容预览:

1
2
3
4
5
6
7
8
9
10
11
12
13
14

println!(generate(llama, MASTER_CONFIG, 500, device)[0])

ZELBETH:
Sey solmenter! 'tis tonguerered if berryishdd, and What his stabe, you, and, but all I pilJefals, mode with,
Vurint as steolated have loven OlD the queen'd refore
Are been, good plmp:

Proforne, wift'es swleen, was no bunderes'd a a quain beath!
Tybell is my gateer stalk smen'd as be matious dazest brink thou
lord
Enves were cIUll, afe and whwas seath This a is, an tale hoice his his onety Meall-tearn not murkawn, fase bettizen'd her,
To belacquesterer? baxewed wupl usweggs yet tall
An

实现过程中可能涉及一些 Rust 的使用方法,与 Python 有所不同,这里不做过多说明,具体的 Rust 语法可以参考其他文档。

迭代工作:从小模块开始,保持确定性,然后逐步构建

  1. 创建所有需要的辅助函数,以便定量测试模型(数据拆分、训练、绘制损失)。
  2. 从论文中挑选出不同的组件,然后逐一实现,边训练边评估。

确保你的层按预期工作

  1. 经常使用 .shape()assert 是你的朋友。
  2. 先在不进行矩阵乘法的情况下计算结果,然后使用 candle 函数使其高效。
  3. 有一个测试来确保你的层是正确的。例如,RoPE 嵌入有一个特定的属性,你可以测试它。对于 Transformer,你可以通过查看注意力矩阵来测试注意力是否正常工作。
  4. 在各种批次、序列和嵌入大小上测试你的层。即使它适用于一种大小,它可能不适用于其他大小,这将在推理时导致问题。

关于 Llama

Llama 是一种基于 Transformer 的语言建模模型。它是一个自回归模型,也称为 CausalModel,模型会将输出中的 token 加入到输入中,不断迭代推理,直到超过上下文长度或遇到停止符。Meta AI 开源了 Llama,并明确表示他们的目标是使模型在推理时更高效,而不是优化训练成本。

接下来,我们将加载库并开始实现。

1
2
3
4
5
6
7
8
9
10
11
12
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::ops::softmax;
use candle_nn::{
embedding, linear, loss, AdamW, Embedding, Init, Linear, Module, Optimizer, ParamsAdamW,
VarBuilder,
};

use core::f32;
use rand::Rng;
use std::collections::HashMap;
use std::fs;
use std::time;

设置数据集

虽然 Llama 在 1.4T 个标记上进行训练,但我们的数据集 TinyShakespeare,即莎士比亚所有作品的集合,大约只有 100 万个字符。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
use std::collections::HashMap;
use std::fs;

fn main() {
// Read the entire content of the file
let lines = fs::read_to_string("./input.txt")
.expect("Failed to read the file");

// Create a sorted set of unique characters
let mut vocab: Vec<char> = lines.chars().collect();
vocab.sort_unstable();
vocab.dedup();

// Create itos and stoi mappings
let itos: HashMap<usize, char> = vocab.iter().enumerate().map(|(i, &ch)| (i, ch)).collect();
let stoi: HashMap<char, usize> = vocab.iter().enumerate().map(|(i, &ch)| (ch, i)).collect();

// Print the first 30 characters of the file
println!("{}", &lines[..30.min(lines.len())]);
}
First Citizen:
Before we proce

他们使用了SentencePiece字节对编码分词器,但我们将只使用一个简单的字符级分词器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use std::collections::HashMap;
use std::fs;
struct Vocab {
itos: HashMap<u32, char>,
stoi: HashMap<char, u32>,
}

impl Vocab {
fn new(itos: HashMap<u32, char>, stoi: HashMap<char, u32>) -> Vocab {
Vocab { itos, stoi }
}
fn decode(&self, ids: &[u32]) -> String {
ids.iter().map(|&id| self.itos[&id]).collect()
}
fn encode(&self, text: &str) -> Vec<u32> {
text.chars().map(|ch| self.stoi[&ch]).collect()
}
fn len(&self) -> usize {
self.itos.len()
}
fn build(lines: &str) -> Self {
// Create a sorted set of unique characters
let mut vocab: Vec<char> = lines.chars().collect();
vocab.sort();
vocab.dedup();

// Create itos and stoi mappings
let itos: HashMap<u32, char> = vocab
.iter()
.enumerate()
.map(|(i, &ch)| (i as u32, ch))
.collect();
let stoi: HashMap<char, u32> = vocab
.iter()
.enumerate()
.map(|(i, &ch)| (ch, i as u32))
.collect();
Self { itos, stoi }
}
}

fn main() {
// Read the entire content of the file
let lines = fs::read_to_string("./input.txt").expect("Failed to read the file");

let vocab = Vocab::build(&lines);
println!("vocab size = {}", vocab.len());
println!("{}", vocab.decode(&vocab.encode("hello")));
}
vocab size = 65
hello

由于数据集较小,我们无需担心内存存储问题。

我们创建了一个 config 对象来存储基本的模型参数。这样可以提高代码的可读性,并且便于修改配置。Rust 是强类型语言,因此所有变量都有明确的类型。

1
2
3
let mut modeConfig = ModelConfig {
vocab_size: vocab.len(),
}
1
2
let dataset = Tensor::from_slice(&vocab.encode(&lines), (lines.len(),), &Device::Cpu).unwrap();
println!("{:?}", dataset.shape());
[1115394]

让我们创建一个方法 get_batches 来生成训练数据和目标的批次。我们将使用相同的方法来生成验证和测试数据,通过 split 参数来控制。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
fn get_batches(
dataset: &Tensor,
split: &str,
batch_size: usize,
context_length: usize,
) -> Result<(Tensor, Tensor)> {
let len_of_dataset = dataset.shape().dim(0).unwrap() as f32;
// 按照 0.8 0.1 0.1 的比例切分训练集, 验证集和测试集
let batch_data = match split {
"val" => &dataset.i((0.8 * len_of_dataset) as usize..(0.9 * len_of_dataset) as usize)?,
"test" => &dataset.i((0.9 * len_of_dataset) as usize..)?,
_ => &dataset.i(..(0.8 * len_of_dataset) as usize)?,
};
// 生成随机index
let mut rng = rand::thread_rng();
let data_len = batch_data.shape().dim(0)?;
let indices: Vec<usize> = (0..batch_size)
.map(|_| rng.gen_range(0..data_len - context_length - 1))
.collect();
let mut x_batches = Vec::with_capacity(batch_size);
let mut y_batches = Vec::with_capacity(batch_size);

for idx in indices {
let x = batch_data.i(idx..(idx + context_length))?;
// y 是 x 后面的一个字符
let y = batch_data.i((idx + 1)..(idx + context_length + 1))?;
x_batches.push(x);
y_batches.push(y);
}
// stack 和 cat 的区别是, stack 是在新的维度上堆叠, cat 是在已有的维度上堆叠
let x_tensor = Tensor::stack(&x_batches, 0)?;
let y_tensor = Tensor::stack(&y_batches, 0)?;
Ok((x_tensor, y_tensor))
}
// in fn main
modeConfig.context_length = 16;
modeConfig.batch_size = 8;
let batch = get_batches(
&dataset,
"train",
modeConfig.batch_size,
modeConfig.context_length,
)?;
println!(
"batch size {}, context_length {}",
batch.0.shape().dim(0)?,
batch.0.shape().dim(1)?
);
for i in 0..modeConfig.batch_size {
println!(
"{:?}, {:?}",
vocab.decode(&batch.0.i(i)?.to_vec1()?),
vocab.decode(&batch.1.i(i)?.to_vec1()?),
);
}

":\nBut, that I'll", "\nBut, that I'll "
"ng?\nWhy, then th", "g?\nWhy, then the"
"s so blind, but ", " so blind, but s"
"thy offices,\nSo ", "hy offices,\nSo r"
"ords, how plainl", "rds, how plainly"
"IET:\nHere's such", "ET:\nHere's such "
"wer\nTo take off ", "er\nTo take off s"
" hurry from the ", "hurry from the f"

实现论文有趣的一点在于,模型“工作”有两个方面:编译(你的张量是否在各层之间匹配)和训练(损失是否下降)。
我们还要定义评估模型的方法。我们希望在定义模型之前就这样做,因为我们希望在训练模型时使用它来评估模型的性能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
fn evaluate_loss(
model: &SimpleBrokenModel,
dataset: &Tensor,
vocab: &Vocab,
config: ModelConfig,
) -> Result<HashMap<String, f32>> {
let mut out = HashMap::new();
for split in ["train", "val"] {
let mut losses = Vec::new();
for _ in 0..10 {
let (xs, ys) = get_batches(&dataset, split, config.batch_size, config.context_length)?;
let (_, loss) = model.forward(&xs, Some(&ys))?;
let loss = loss.unwrap();
losses.push(loss.to_scalar::<f32>()?);
}
let avg_loss = losses.iter().sum::<f32>() / losses.len() as f32;
out.insert(split.to_owned(), avg_loss);
}
Ok(out)
}

设置一个可以工作的简单模型

这是一个带有嵌入的基本前馈神经网络。它是我们将要开始的基础模型,然后我们将逐步替换其部分内容,直到最终得到 Llama 论文中描述的模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
struct SimpleBrokenModel {
embedding: Embedding,
mlp: Sequential,
config: ModelConfig,
}

impl SimpleBrokenModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
// 潜入层
let embeds = self.embedding.forward(x)?;

// 线性和激活层
let logits = self.mlp.forward(&embeds)?;

// 如果提供了targets就计算loss,不然视为推理,计算logits就可以。
if let Some(targets) = targets {
// 负的似然函数
// -log(x) 越大,loss 越小
// y = [0, 0 , 0, 0, 1, ...,0,0]
// y' = [4, 5, 6, 7, 8, ...,11,12 ]
// 这个 cross_entropy 帮我们做了一个 log softmax
// y' = [0.1, 0.12, 0.13, 0.64, ..., 0,0]
// loss = -log(0.64)
// 当 -log(q) = 4.17 q = 0.015 大概 1/64,vocab_size = 65,所以基本是在瞎猜。
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}
// VarBuilder是用来构建参数的,我们目前不加载和保存模型参数,但是candle的用法必须基于这个。
// vb.pp 会在参数树中加入参数的前缀,这样可以方便的查看参数的结构。
fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
})
}
}
// in fn main
modeConfig.d_model = 128;
modeConfig.batch_size = 32;
let (xs, ys) = get_batches(
&dataset,
"train",
modeConfig.batch_size,
modeConfig.context_length,
)?;
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let model = SimpleBrokenModel::load(vb, modeConfig)?;
let (logits, loss) = model.forward(&xs, Some(&ys))?;
println!("{:?} {:?}", logits, loss);
let mut params_count: usize = 0;
for (name, var) in varmap.data().lock().unwrap().iter() {
println!("{}: {:?}", name, var.elem_count());
params_count += var.elem_count();
}
println!("params count: {}", params_count);
Tensor[dims 32, 16, 65; f32] Some(Tensor[5.266067; f32])
model.fc2.weight: 8320
model.embed_tokens.weight: 8320
model.fc1.bias: 128
model.fc2.bias: 65
model.fc1.weight: 16384
params count: 33217

在这一点上,我们必须开始关注张量的形状,并让矩阵的维度匹配。查看我们模型定义中的这一行:

1
2
3
4
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;

我们必须调整 logitstargets 张量的形状,以便在比较时它们的维度匹配。我们使用 reshape 方法来实现这一点。
() 参数的意思是“从其他维度推断这个维度”。所以,在这种情况下,我们是在说“将 logitstargets 重新调整为具有相同行数的形状,并使用所需的列数来实现这一点”。这是处理批量数据时的常见模式。

让我们训练我们的 SimpleBrokenModel 以确保梯度流动。在确认这一点之后,我们可以替换它的部分内容以匹配 Llama,再次训练并跟踪我们的进展。在这一点上,我开始记录我的训练运行日志,这样如果我搞砸了,我可以轻松地回到之前的运行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
fn train(
config: ModelConfig,
model: &SimpleBrokenModel,
opt: &mut AdamW,
dataset: &Tensor,
vocab: &Vocab,
) -> Result<()> {
let mut start_time = std::time::Instant::now();
for epoch in 0..config.epochs {
let (xs, ys) = get_batches(&dataset, "train", config.batch_size, config.context_length)?;
let (_, loss) = model.forward(&xs, Some(&ys))?;
opt.backward_step(&loss.unwrap())?;
if epoch % config.log_interval == 0 {
let batch_duration = start_time.elapsed().as_secs_f32();
let loss = evaluate_loss(&model, dataset, vocab, config)?;
let val_loss = loss.get("val").unwrap();
let eta = batch_duration * (config.epochs - epoch) as f32;
let eta = eta.round();
println!(
"Epoch: {epoch} | Loss: {val_loss} | Time: {batch_duration} | ETA in seconds {eta}"
);
start_time = time::Instant::now();
}
}
Ok(())
}
// in fn main
modeConfig.log_interval = 10;
modeConfig.epochs = 100;
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), ParamsAdamW::default())?;
train(modeConfig, &model, &mut opt, &dataset, &vocab)?;
let out = evaluate_loss(&model, &dataset, &vocab, modeConfig);
println!("{:?}", out);
Epoch: 10 | Loss: 3.9159875 | Time: 6.5813394 | ETA in seconds 599
Epoch: 20 | Loss: 3.26492 | Time: 6.3639965 | ETA in seconds 515
Epoch: 30 | Loss: 2.9944448 | Time: 6.3596206 | ETA in seconds 452
Epoch: 40 | Loss: 2.8793342 | Time: 6.357106 | ETA in seconds 388
Epoch: 50 | Loss: 2.7827232 | Time: 6.3562865 | ETA in seconds 324
Epoch: 60 | Loss: 2.764416 | Time: 6.352279 | ETA in seconds 260
Epoch: 70 | Loss: 2.7196321 | Time: 6.356127 | ETA in seconds 197
Epoch: 80 | Loss: 2.7631993 | Time: 6.357493 | ETA in seconds 134
Epoch: 90 | Loss: 2.696882 | Time: 6.358631 | ETA in seconds 70
Epoch: 100 | Loss: 2.670012 | Time: 6.3603354 | ETA in seconds 6
Ok({"train": 2.591057, "val": 2.6625311})
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
fn generate(model: &SimpleBrokenModel, vocab: &Vocab, max_tokens: usize) -> Result<()> {
// batch size 5, initial token = 0
let mut token_ids = Tensor::zeros((5, 1), DType::U32, &Device::Cpu).unwrap();
for _ in 0..max_tokens {
let (logits, _) = model.forward(&token_ids, None)?;
assert!(logits.shape().dims() == [token_ids.dim(0)?, token_ids.dim(1)?, 65]);
let last_step_logits = logits.i((.., logits.dim(1)? - 1))?;
assert!(last_step_logits.shape().dims() == [token_ids.dim(0)?, 65]);
let probs = softmax(&last_step_logits, last_step_logits.dims().len() - 1)?;
assert!(probs.shape().dims() == [token_ids.dim(0)?, 65]);
let next_token = probs.argmax(probs.dims().len() - 1)?;
assert!(next_token.shape().dims() == [token_ids.dim(0)?]);
token_ids = Tensor::cat(&[token_ids, next_token.reshape(((), 1))?], 1)?;
}
let lines = fs::read_to_string("./input.txt").expect("Failed to read the file");
for v in &token_ids.to_vec2()? {
let text = vocab.decode(v);
println!("{}", text);
}
Ok(())
}
// fn in main
generate(&model, &vocab, 10, device)?;
['\nFind!\nD:\nAr t,\nLis sthte o t l',
 '\nAnd ronnot ar\nBE:\nKINRDYOrspr;',
 '\nI t athe momyengthend thanswal',
 '\nFis t bp he\nLacarn.\nA:\nYOMI wi',
 '\nWh ly sck\nB-de pll t\nHERIns ou']

这还算不错,但也不算太好。不过现在我们有了一个可以训练到验证损失的工作模型。因此,我们将在此基础上迭代我们的模型,使其更接近 Llama。

Llama 具体细节

Llama 对原始 Transformer 进行了三项架构修改:

  1. 用于预归一化的 RMSNorm
  2. 旋转嵌入 RoPE
  3. SwiGLU 激活函数

我们将逐一添加每个修改到我们的基础模型,并进行迭代。

RMSNorm

在 Vaswani 2017 中,原始的 Transformer 使用了 BatchNormalization。在 Llama 中,作者使用了 RMSNorm,这是一种在不进行中心化的情况下通过方差来缩放向量的方法。此外,虽然 Vaswani 将归一化应用于注意力层的输出(后归一化),但 Llama 将其应用于输入之前(前归一化)。

这篇文章对于RMSNorm有一个很好的解释。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
pub struct RmsNorm {
scale: Tensor,
eps: f64,
}
impl RmsNorm {
fn new(size: usize, vb: VarBuilder) -> Result<Self> {
Ok(RmsNorm {
scale: vb.get_with_hints(size, "weight", Init::Const(1.))?,
eps: 1e-6,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_sqr = x.sqr()?;
assert!(x_sqr.shape().dims() == x.shape().dims());
let norm_x = (x.mean(D::Minus1)? + self.eps)?.sqrt()?;
assert!(norm_x.shape().dims() == [x.shape().dim(0)?, x.shape().dim(1)?]);
let x_normed = x.broadcast_div(&norm_x.reshape((
norm_x.shape().dim(0)?,
norm_x.shape().dim(1)?,
(),
))?)?;
assert!(x_normed.shape().dims() == x.shape().dims());
let x = (x_normed.broadcast_mul(&self.scale))?;
Ok(x)
}
}

let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);

let rms_norms = RmsNorm::new(2, vb)?;
// (2,3,2)
let batch = Tensor::new(
vec![
vec![vec![1f32, 1f32], vec![1.2f32, 2f32], vec![3f32, 3f32]],
vec![vec![4f32, 43f32], vec![5f32, 5f32], vec![61f32, 6f32]],
],
&Device::Cpu,
)?;
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let out = rms_norms.forward(&batch)?;
Tensor[dims 2, 3, 2; f32]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
struct SimpleBrokenModel {
embedding: Embedding,
mlp: Sequential,
rms_norm: RmsNorm,
config: ModelConfig,
}

impl SimpleBrokenModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
// Embedding
let embeds = self.embedding.forward(x)?;
// RMSNorm
let normed_embeds = self.rms_norm.forward(&embeds)?;
// Linear layers and activation
let logits = self.mlp.forward(&normed_embeds)?;

// Calculate loss if targets are provided
if let Some(targets) = targets {
// 负的似然函数
// log(x) 越大,loss 越小
// y = [0, 0 , 0, 0, 1, ...,0,0]
// y' = [4, 5, 6, 7, 8, ...,11,12 ]
// 这个 cross_entropy 帮我们做了一个 log softmax
// y' = [0.1, 0.12, 0.13, 0.64, ..., 0,0]
// loss = -log(0.64)
// -log(q) = 4.17 q = 0.015 大概 1/64,vocab_size = 65,所以基本是在瞎猜。
// println!("{:?}", targets.shape());
// println!("{:?}", logits.shape());
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}
// VarBuilder是用来构建参数的。
fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
rms_norm,
})
}
}

Epoch: 10 | Loss: 4.1559505 | Time: 6.779387 | ETA in seconds 617
Epoch: 20 | Loss: 4.14648 | Time: 6.7727704 | ETA in seconds 549
Epoch: 30 | Loss: 4.1364665 | Time: 6.776428 | ETA in seconds 481
Epoch: 40 | Loss: 4.125594 | Time: 6.772582 | ETA in seconds 413
Epoch: 50 | Loss: 4.120083 | Time: 6.7661977 | ETA in seconds 345
Epoch: 60 | Loss: 4.1099877 | Time: 6.760399 | ETA in seconds 277
Epoch: 70 | Loss: 4.0996284 | Time: 6.7623253 | ETA in seconds 210
Epoch: 80 | Loss: 4.0902996 | Time: 6.761824 | ETA in seconds 142
Epoch: 90 | Loss: 4.0833025 | Time: 6.76845 | ETA in seconds 74
Epoch: 100 | Loss: 4.070025 | Time: 6.7624454 | ETA in seconds 7
Ok({"train": 4.072861, "val": 4.0711236})

从这里得到的结果来看,范化以后,模型的表现并没有提升,所以我们需要继续迭代,只是梯度的下降变得比较平滑了。

Rotary Embeddings

RoPE 是一种用于 Transformer 的位置编码方法。在《Attention is All You Need》中,作者提出了两种位置编码方法:学习的和固定的。在 RoPE 中,作者通过旋转嵌入来表示序列中标记的位置,并在每个位置使用不同的旋转角度。

其中的 cos 和 sin 值可以预先计算并缓存,避免重复计算,后续会统一存放在一个缓存结构中。

RoPE 将 hidden_state 中每两个 x 组成的向量与旋转矩阵相乘来实现位置编码。

1
2
3
4
[x0, x1, .... ,xn]
y0 = x0 * cos(theta) - x1 * sin(theta)
y1 = x0 * sin(theta) + x1 * cos(theta)
[y0, y1, ...., yn]

nd_model 的一半。
theta 是一个根据位置得到的固定值,计算公式为:
theta = m / 10000^(2i/n)
其中,m 是在序列中的位置,i 是在 d_model 中的位置。

这个公式的含义是将特征向量中的 x0x1 进行一个固定的旋转,这个旋转不是通过学习得到的,而是预先计算的。它可以用于表示相对位置信息。

1
2
3
```Rust
pos_index = 0, 中的 x0,x1
pos_index = 1, 中的 x0,x1

隔了一个恒定的调度旋转。

freq_cis缓存住提前算好的cos和sin的值,这部分不用重复计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
struct Cache {
cos: Tensor,
sin: Tensor,
}

impl Cache {
fn new(context_length: usize, n_elem: usize, vb: VarBuilder) -> Result<Cache> {
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), vb.device())?;
let idx_theta = Tensor::arange(0, context_length as u32, vb.device())?
.to_dtype(DType::F32)?
.reshape((context_length, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let freq_cis_real = idx_theta.cos()?;
let freq_cis_imag = idx_theta.sin()?;

let cos = freq_cis_real.reshape((context_length, n_elem / 2, 1))?;
let sin = freq_cis_imag.reshape((context_length, n_elem / 2, 1))?;
Ok(Cache { cos, sin})
}
}

rope计算的时候

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
fn apply_rotary_emb(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
// println!("shape of cache.cos {:?}", cache.cos.shape());
let cos = cache.cos.i(..seq_len)?;
let sin = cache.sin.i(..seq_len)?;
// println!("cos shape {:?}", cos.shape());
let cos = cos.broadcast_as((b_sz, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, n_embd / 2, 1))?;
// println!("broadcast cos shape {:?}", cos.shape());
let x = x.reshape((b_sz, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, n_embd))?;
Ok(rope)
}

Self Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
struct AttentionModel {
embedding: Embedding,
mlp: Sequential,
rms_norm: RmsNorm,
config: ModelConfig,
self_attention: SelfAttention,
cache: Cache,
}

impl AttentionModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
// 嵌入层
let embeds = self.embedding.forward(x)?;
// 范化层
let normed_embeds = self.rms_norm.forward(&embeds)?;
// 自注意力层
let y = self.self_attention.forward(&normed_embeds, &self.cache)?;
// 线性和激活层
let logits = self.mlp.forward(&normed_embeds)?;

if let Some(targets) = targets {
// 负的似然函数
// -log(x) 越大,loss 越小
// y = [0, 0 , 0, 0, 1, ...,0,0]
// y' = [4, 5, 6, 7, 8, ...,11,12 ]
// 这个 cross_entropy 帮我们做了一个 log softmax
// y' = [0.1, 0.12, 0.13, 0.64, ..., 0,0]
// loss = -log(0.64)
// 例如 -log(q) = 4.17 q = 0.015 大概 1/64,vocab_size = 65,所以基本是在瞎猜。
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}

fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
let self_attention = SelfAttention::load(vb.pp("model.self_attention"), config)?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
rms_norm,
self_attention,
cache: Cache::new(config.context_length, config.d_model, vb)?,
})
}
}

提示:了解训练时张量维度与推理时张量维度的区别。

虽然在训练时,你可以期望张量维度与模型参数紧密匹配,例如 batch.shape = (config['batch_size'], config['context_window'], config['d_model']),但在推理时,你可能需要处理单个示例,例如 batch.shape = (1, 1, config['d_model'])。因此,你需要确保在 forward 传递中进行索引时,使用从输入派生的形状,而不一定是模型参数。

MultiHeadRopeAttention

让我们为这个单一的注意力头设置一个多头注意力层,看看训练时会发生什么。

这里实现的是GQA的注意力头。n_kv_head=1时就是MQA,n_kv_head>1n_kv_head<n_head时就是GQA,n_kv_head=n_head时就是原本的MHA。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
struct MultiHeadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
}

impl MultiHeadAttention {
fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
let q_proj = linear(config.d_model, config.d_model, vb.pp("model.q_proj"))?;
let k_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.k_proj"),
)?;
let v_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.v_proj"),
)?;
let o_proj = linear(config.d_model, config.d_model, vb.pp("model.o_proj"))?;
println!(
"MHA config n_head {} n_kv_head {}",
config.n_head, config.n_kv_head
);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_head: config.n_head,
n_kv_head: config.n_kv_head,
head_dim: config.d_model / config.n_head,
})
}
fn apply_rotary_emb(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = cache.cos.i(..seq_len)?;
let sin = cache.sin.i(..seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
Ok(rope)
}
fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;

// 计算 q k v
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;

assert!(n_embd == self.n_head * self.head_dim);

let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
let k = k.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
let v = v.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;

// 对 q 和 k 做位置编码
let q = self.apply_rotary_emb(&q, cache)?;
let k = self.apply_rotary_emb(&k, cache)?;
// 复制成 n_head / n_kv_head 份
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;

// 把 seq_len 和 n_head 交换
// 这转换一下是为了做一个 cat single head 的简单操作
// 相当于 n_head 个的seq_len*seq_len的注意力。
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;

// q*k^T / sqrt(d_k) d_k = d_model
// 这里是 (bs,n_head) 个 (seq_len, seq_len) 的注意力
let attn = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;

// 这里是头内的softmax (seq_len,seq_len)的行总和为1
let attn = softmax(&attn, D::Minus1)?;

// 再乘 * (bs, n_head, seq_len, head_dim) 得到 (bs, n_head)个注意力头对应的加权的v
let y = attn.matmul(&v)?;
// 把 n_head 和 seq_len 交换回来,得到 (bs, seq_len, n_head, head_dim) 然后reshape以后
// 得到 (bs, seq_len, n_head * head_dim) 把头给cat到一起。
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}

fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
let x = x
.unsqueeze(3)?
.expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
.reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
Ok(x)
}
}
}

完整模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
struct AttentionModel {
embedding: Embedding,
mlp: Sequential,
rms_norm: RmsNorm,
config: ModelConfig,
self_attention: MultiHeadAttention,
cache: Cache,
}

impl AttentionModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {

let embeds = self.embedding.forward(x)?;

let normed_embeds = self.rms_norm.forward(&embeds)?;
let y = self.self_attention.forward(&normed_embeds, &self.cache)?;

let logits = self.mlp.forward(&y)?;

if let Some(targets) = targets {
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}

fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
let self_attention = MultiHeadAttention::load(vb.pp("model.multi_head_attention"), config)?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
rms_norm,
self_attention,
cache: Cache::new(config.context_length, config.d_model/config.n_head, vb)?,
})
}
}

1
generate(&model, &vocab, 10, device)?;
['\n\n\n\n\n\n\n\nI\n\nOOOOOOOOOFOOtOOOOOOO',
 '\nIIIIII IIIIIIIIIIIIIIIIIIIIIII',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\naaame',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n']

所以看起来很糟糕。这里发生了什么?让我们通过查看注意力来开始调试。

目前的注意力是没有masked的,任何位置的字符都在关注任何其他位置的字符。
这有什么不好呢?我们试图仅基于之前的标记来预测下一个标记,但这里我们看到模型正在关注之后的标记。
换句话说,模型在作弊,或者从未来泄露信息。这是一个问题,这就是为什么我们需要使用因果掩码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132

pub struct Cache {
// ...
// mask 也可以 cached
mask: Tensor,
}
impl Cache {
fn new(context_length: usize, n_elem: usize, vb: VarBuilder) -> Result<Cache> {
// ...
// _ 表示类型由编译器推断,
// 默认的collect 是 [_],但是这个大小是不可变的要编译期间决定
// 所以这里还是要提示编译器要 collect 成 vec.
let mask: Vec<_> = (0..context_length)
.flat_map(|i| (0..context_length).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (context_length, context_length), vb.device())?;
Ok(Cache { cos, sin, mask })
}
}

struct MultiHeadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
}

impl MultiHeadAttention {
fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
let q_proj = linear(config.d_model, config.d_model, vb.pp("model.q_proj"))?;
let k_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.k_proj"),
)?;
let v_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.v_proj"),
)?;
let o_proj = linear(config.d_model, config.d_model, vb.pp("model.o_proj"))?;
println!(
"MHA config n_head {} n_kv_head {}",
config.n_head, config.n_kv_head
);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_head: config.n_head,
n_kv_head: config.n_kv_head,
head_dim: config.d_model / config.n_head,
})
}
// ...
fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;

// 计算 q k v
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;

assert!(n_embd == self.n_head * self.head_dim);

let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
let k = k.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
let v = v.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;

// 对 q 和 k 做位置编码
let q = self.apply_rotary_emb(&q, cache)?;
let k = self.apply_rotary_emb(&k, cache)?;
// 复制成 n_head / n_kv_head 份
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;

// println!("q.shape {:?}", q.shape());
// 把 seq_len 和 n_head 交换
// 这转换一下是为了做一个 cat single head 的简单操作
// 相当于 n_head 个的seq_len*seq_len的注意力。
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;

// let tmp = q.matmul(&k.t()?)?;
// 这个结果是有负数的,但是注意力层不会有负数。
// 计算结果出了很多 NaN,感觉应该是范化没做好。
// 后面发现是没有用 sqrt 用了开方,导致负数变成了NaN。
// q*k^T / sqrt(d_k) d_k = d_model
// 这里是 (bs,n_head) 个 (seq_len, seq_len) 的注意力
let attn = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
// 在 softmax 之前,把未来的token位置变为负无穷,这样在softmax之后,这些位置的概率就会变为0
let mask = cache
.mask
.i((..seq_len, ..seq_len))?
.unsqueeze(0)?
.unsqueeze(0)?
.broadcast_as(attn.shape())?;
let on_true =
Tensor::new(f32::NEG_INFINITY, attn.device())?.broadcast_as(mask.shape().dims())?;
let attn = mask.where_cond(&on_true, &attn)?;
// 取一个例子
// 这里是头内的softmax (seq_len,seq_len)的每行 (seq_len) 总和为1
let attn = softmax(&attn, D::Minus1)?;

// 再乘 * (bs, n_head, seq_len, head_dim) 得到 (bs, n_head)个注意力头对应的加权的v
let y = attn.matmul(&v)?;
// 把 n_head 和 seq_len 交换回来,得到 (bs, seq_len, n_head, head_dim) 然后reshape以后
// 得到 (bs, seq_len, n_head * head_dim) 把头给cat到一起。
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}

fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
let x = x
.unsqueeze(3)?
.expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
.reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
Ok(x)
}
}
}

现在,我们可以让注意力激活的上三角部分(对应未来的部分)几乎被归零了。让我们看看训练时会发生什么。

SwiGLU

正如论文中所述,“我们用SwiGLU激活函数替换了ReLU非线性函数……我们使用$\frac{2}{3} 4d$的维度,而不是PaLM中的$4d$。” SwiGLU定义为:
$$
\text{SwiGLU}(x) = \text{Swish}_\beta (xW + b) \otimes (xV + c)
$$
其中$\otimes$是逐元素乘积。Swish函数定义为:
$$
\text{Swish}_\beta(x) = x \sigma(\beta x)
$$
其中$\beta$是一个可学习的参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}

struct SwiGLU {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
}
// 新的 mlp 是三层的带gate
impl SwiGLU {
// silu 的特征是允许有一点点的负数
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// 这里就是 SwiGLU SiLU(W_1 * x) * (W_2 * x) 是 element wise 的,这个可以作为gate门信号
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}

fn load(vb: VarBuilder, cfg: ModelConfig) -> Result<Self> {
let h_size = cfg.d_model;
let i_size = cfg.hidden_dim;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
Ok(Self {
c_fc1,
c_fc2,
c_proj,
})
}
}

一个llama block res 两次,一次在attention之前,一次在swiglu之前。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
struct Block {
rms_1: RmsNorm,
attn: MultiHeadAttention,
rms_2: RmsNorm,
swiglu: SwiGLU,
}

impl Block {
fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, cache)? + residual)?;
let residual = &x;
let x = (self.swiglu.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}

fn load(vb: VarBuilder, cfg: ModelConfig) -> Result<Self> {
let attn = MultiHeadAttention::load(vb.pp("self_attn"), cfg)?;
let swiglu = SwiGLU::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::new(cfg.d_model, vb.pp("input_layernorm"))?;
let rms_2 = RmsNorm::new(cfg.d_model, vb.pp("post_attention_layernorm"))?;
Ok(Self {
rms_1,
attn,
rms_2,
swiglu,
})
}
}

现在,让我们通过创建块来添加多个层的最后完整的模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
struct Llama {
embedding: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
cache: Cache,
config: ModelConfig,
}

impl Llama {
pub fn forward(
&self,
x: &Tensor,
targets: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.embedding.forward(x)?;
for block in &self.blocks {
x = block.forward(&x, &self.cache)?;
}
let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?;

if let Some(targets) = targets {
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}

pub fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
let embed_layer = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let lm_head = linear(config.d_model, config.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(config.d_model, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..config.n_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), config).unwrap())
.collect();
Ok(Self {
embedding: embed_layer,
blocks,
ln_f,
lm_head,
config,
cache: Cache::new(config.context_length, config.d_model / config.n_head, vb)?,
})
}
}

扩展

这里主要是训练的部分,在推理的过程中还涉及到一个比较重要的kv cache。
kv cache主要是缓存自回归过程中的kv,这个kv是不变的,因为这个k 和 v 只和之前的token有关系,所以可以缓存下来,
这样在推理的时候就不用重复计算了。
围绕着prompt产生第一个Token的prefill的阶段和计算完prompt之后的decode阶段也是目前业界比较关注的推理优化的方向。
源代码在这里

这篇文章主要基于vLLM中做推理的优化做一个总结。

PagedAttention

原始论文来看,显存的浪费主要有几种。

这张图里面表示的是每个token对应的kvcache的slot。对于一个context长度中用不到的slots部分,有预留的slots和不同空隙之间的slots空隙。

vLLM参考了虚拟内存和内存页分配的逻辑构造了一个block table用于 kv cache block slots 和 token之间的关系。通过block表的管理
可以像操作系统一下减少内存碎片。
因为不同的sequence中token是有位置信息的,所有他们对应的kv slot也不一定一样。下图展示了他们的关系。

Continuous Batching

“连续批量处理”(Continuous Batching),也称为”动态批量处理”(Dynamic Batching)或”迭代级调度批量处理”(Batching with Iteration-Level Scheduling),是一种选择Batch行的技术,用于优化计算和资源利用率。
它的主要区别在于与静态批处理相比,它会在每次推理迭代过程中动态调整Batch中的序列,例如vLLM会抢占部分序列,将序列的prefill阶段和decode阶段分开,不在同一个batch中处理。

而HuggingFace的text-generation-inference的router文档中提到的:为了提高效率,特别是在文本生成和内存受限的LLM上,可以让客户端发送单个查询,然后路由器将这些查询合并或分离成批次,以最大限度地利用计算资源。这是因为在LLM中,运行模型的计算成本远远高于批处理的成本。当新请求到达时,当前正在forward的前向传播不中断,而是继续等待执行完毕。然后将当前正在处理的请求与新到的请求合并成一个批处理请求,再进行forward前向传播。在批处理请求中,任何一个请求完成(即模型产生了终止符或达到允许的最大长度),则从批处理请求中移除该请求,并释放相关资源。这种方法可以应用于多个请求,并且支持在不同参数下的处理(例如采样、不采样、温度控制等),因为每个请求在批处理中都可以独立进行处理。Anyscale 对这个过程有很好的解释

Prefill 和 Decode

在LLM推理过程中,一个Prompt的第一次执行(称为prefill)和后续的前向传播(称为decode)是不同的。Prefill阶段需要计算整个注意力矩阵并将其缓存到KV缓存中,计算规模较大,尤其是对于长度为10K或100K的提示词。而在decode阶段,只需计算新生成的token的注意力矩阵,计算规模较小。

从Kimi的Mooncake论文中的图片来看:

左图显示,当prompt长度增加时,计算时间呈平方级别增加。右图显示,decode阶段只生成一个token,计算规模比线性增长还要慢一点,但由于需要复用之前的KV缓存,因此显存开销较大。Mooncake的解决方案是将prefill和decode分开处理,prefill计算规模大,decode计算规模小,通过KV缓存共享机制传递KV缓存。论文还提到了一些关于缓存调度的细节,这里不再展开。

在Prefill阶段,所有的提示词(prompt)都是已知的,因此可以并行计算多个token,计算并行度较高。而在Decode阶段,只需根据新生成的token计算下一个token(之前的KV缓存已经保存了中间结果),因此计算规模较小。如果将Prefill和Decode放在同一个batch中计算,由于计算规模不对等,容易产生计算空隙(bubble)。

Prefill像是一个矩阵和矩阵的乘法,而Decode则是一个向量和矩阵的乘法。

Chunked Prefill

vLLM的文档很好的解释了prefill阶段和decode阶段的区别。

考虑到vLLM进行生成的序列“ABC”。当它到达时,KV缓存基于block size的预设值(这里是2)在内存中分配对应的block(B1,B2,B3,B4),但它是空的。
我们知道序列的内容(A,B,C),但我们没有token id到块索引的映射。

考虑到这种情况,下一步是为序列ABC进行prefill。在调度过程中,我们为序列中每个token块分配块索引(B3,B4),即([A, B], [C, _])。

一旦确定了块映射,我们就可以通过运行模型的前向传播。这会将ABC token的KV激活值写入KV缓存中的相应位置。此外,前向传播将导致新token “D” 被采样。D的KV值尚未知晓。

现在,序列已经完成了预填充。我们可以安排一个解码步骤。这涉及为新token “D” 分配块映射。然而,由于它适合现有的块映射,调度器不需要分配新的映射。

然后我们再次运行模型,计算并将“D”的KV写入KV缓存。这会生成一个新的token “E”。

这个过程会重复进行直到解码完成。请注意,后续的分配可能是不连续的。

Speculative Decoding

Speculative Decoding利用小型、快速的草案(draft)模型生成初始token(token是输入信息的基本单元)时的高效性,而在验证阶段则依赖更大的、更准确的大型语言模型(LLM)进行验证。

这个过程可以分为两步:

  1. 初始token生成:使用小型、快速的草案模型生成初始token。这一步骤快速生成token序列,使得下一步骤可以快速开始。
  2. 验证:使用更大的、更准确的LLM对初始token进行验证。这一步骤确保生成的token序列是准确的和有效的。

这种技术通过将任务分成两个阶段来实现高效性和准确性:快速生成初始token,然后验证这些token以确保准确性。

其实现方式如下:

  • 使用小模型进行多次decode,生成多个token序列。
  • 将这些token序列传递给大模型进行验证。大模型会生成对应的logits,形状为(batch, sequence, vocab_size),并进行softmax处理。
  • 对于单步decode,logits的形状为(batch, 1, vocab_size),其中的1表示最后一个token,在词汇表(vocab_size)上的概率分布。

例如,将小模型生成的”abcd” token序列传递给大模型,得到”ABCD”的logits。A对应的是用于预测B的概率分布,在这个序列中就是预测第二个token的概率分布。将b对应的token id在A的vocab_size长度的log prob中的值取出,对应的可能是B token id的概率,也可能不是。

1
2
3
a | b | c | d
^ ^ ^
| A | B | C | D

最后,将b、c、d在大模型中对应的logits prob求和并取平均值。如果这个值大于一个阈值,就认为这个token是合理的,否则就拒绝。

这种方法通过结合小模型的快速生成能力和大模型的高准确性,实现了高效且准确的token生成过程。

根据论文[https://arxiv.org/pdf/2406.14066]来说,在continus batching的情况下,Speculative Decoding可以提高推理速度,减少计算资源的浪费。

在Target模型验证完以后,还会多生成一个bonus token,也就是上面的那个D。

其中Decoding也产生了很多的方法,有基于模型的,也有model free的,才用大模型的一部分,或者直接从外部数据库来获取。

结果表明,在低请求率下(具体来说,请求率为 4),提出 3 个或 5 个令牌会带来最显著的加速。然而,随着请求率的增加,提出更多令牌的优势迅速减弱:当请求率超过 12 时,提出 5 个令牌不再带来性能提升。同样,在请求率大于 16 时,提出 3 个令牌会导致性能下降。

其中不同的颗粒度的猜测长度也会对性能有影响。

全局统一的长度;每个step所有request用一个长度;每个step每个request用不同的长度。

相较于吞吐,Goodput规定只有没被拒绝的token才计算,用来衡量最总的性能。

这张图展示了猜测长度和batch size对于Goodput的影响。

对于小批次,要多猜测(propose),小批次尺寸下每次请求需提议超 4 个 token 以实现最大有效吞吐量(goodput),且随着批次尺寸增大,最优猜测长度会降低;
对于大批次,则要少猜测,甚至不进行推测反而能获得更高有效吞吐量,因为大批次下推测失败成本显著增加,超过潜在收益。

除了朴素的猜测模型,里面也提到了Medusa风格的猜测模型。预测3个token就有三个head分别预测每个位置。
里面的例子head 1猜了三个,对了1个,head 2猜了2个对了一个,head3猜了3个全错了,然后加上LLM的bonus token。

SmartSpec 估算Goodput的方法就是根据成功率计算期望长度再乘上Request对应的时间。
然后根据不同的batch size计算Goodput,让Goodput最大化,得到最佳Goodput从而让吞吐最大化。

Automatic Prefix Caching

一般的LLM请求的提示词会非常长,据说OpenAI的系统提示词已经有几K了,这个规模很适合做前缀缓存。
前缀缓存是指将提示词分成多个前缀,然后将这些前缀缓存到KV缓存中,这样在生成token的时候就可以直接使用KV缓存中的值,
而不需要重新计算。这样可以减少计算量,提高推理速度。
当然在vLLM中kv cache是分块的,所以prefix 也是分块的。

1
2
3
4
5
                    Block 1                  Block 2                  Block 3
[A gentle breeze stirred] [the leaves as children] [laughed in the distance]
Block 1: |<--- block tokens ---->|
Block 2: |<------- prefix ------>| |<--- block tokens --->|
Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->|

Multi-LoRA Serving

VLLM当中的LoRA Adaptor是可以动态加载的,因为他本身和基座模型保持独立。
这里要注意的一个点是,如果词汇表修改了,会影响最后的llm head,比如英文基座模型用中文词汇表,那么 vocab_size 就不一样了。
会导致llm head的输出维度不一样,这个时候就需要重新训练llm head。
所以要注意最后的llm head的输出维度。

Tensor Parallelism

张量并行(Tensor Parallelism)是一种将大型模型的计算任务分解到多个GPU上并行执行的技术。它通过将模型的权重矩阵切分成多个子矩阵,并将这些子矩阵分配到不同的GPU上进行计算,从而实现并行计算。

在Transformer模型中,QKV(Query, Key, Value)矩阵乘法是计算量最大的部分之一。通过将QKV矩阵切分成更小的子矩阵,并将这些子矩阵分配到不同的GPU上,可以显著提高计算效率。

例如,对于一个具有d_model维度的QKV矩阵,可以将其切分成n个子矩阵,每个子矩阵的维度为d_model / n。然后,将这些子矩阵分配到n个GPU上进行并行计算。这样,每个GPU只需要计算一个较小的子矩阵,从而减少了计算时间。

张量并行的实现方式如下:

  1. 切分权重矩阵:将模型的权重矩阵切分成多个子矩阵。例如,对于一个d_model x d_model的权重矩阵,可以将其切分成nd_model x (d_model / n)的子矩阵。
  2. 分配子矩阵:将切分后的子矩阵分配到不同的GPU上。例如,将第一个子矩阵分配到GPU 0,第二个子矩阵分配到GPU 1,以此类推。
  3. 并行计算:在每个GPU上并行执行矩阵乘法计算。例如,在GPU 0上计算输入矩阵与第一个子矩阵的乘积,在GPU 1上计算输入矩阵与第二个子矩阵的乘积,以此类推。
  4. 聚合结果:将所有GPU上的计算结果聚合起来,得到最终的输出。例如,将所有子矩阵的乘积结果相加,得到最终的QKV矩阵乘积结果。

通过张量并行,可以显著提高大型模型的计算效率,减少推理时间,从而提高模型的推理速度和性能。

Pipeline Parallelism

Pipeline Parallelism通过将模型的不同层分配到不同的GPU上来实现并行计算,特别是一些大型模型。这种方法可以提高计算效率,但也会引入一些挑战,例如在层之间传递数据时可能会产生延迟。此外,由于不同层的计算时间可能不一致,可能会导致某些GPU在等待其他GPU完成计算时处于空闲状态,从而产生计算空隙(bubble)。

为了减少这些计算空隙,可以使用Chunked Prefill技术。Chunked Prefill通过将计算任务分成更小的块,从而缩短每个计算任务的时间。这使得在流水线并行中可以更灵活地安排计算任务,从而尽可能地填满计算空隙,提高整体计算效率。

通过结合Pipeline Parallelism和Chunked Prefill,可以在保持高效计算的同时,最大限度地利用计算资源,减少计算空隙,提高模型推理的速度和效率。

在vLLM当中 TP=n 且 PP=m 时,vLLM 引擎总共会有 n*m + 1 个进程。即使使用单个 GPU,我们也会有 2 个进程。

衡量LLM的服务指标

在衡量LLM服务性能时,Token相关的数据是一个重要的方面。以下是一些常见的Token相关指标:

Token Throughput

Token Throughput表示每秒生成的token数量。它是衡量模型生成速度的一个重要指标,通常以tokens per minute (TPM)为单位。

Token Latency

Token Latency表示生成一个token所需的时间。它是衡量模型响应速度的一个重要指标,通常以毫秒(ms)为单位。Token Latency包括以下几个子指标:

  • **TTFT (Time To First Token)**:从请求到第一个token生成的时间。例如,当prompt变长时,TTFT会变长;或者当kv cache不足时被抢占,TTFT也会变长。
  • **TBT (Time Between Tokens)**:生成两个token之间的时间。例如,当batch size变大时,TBT会变大。

vLLM的一个主要贡献就是PagedAttention,可以实现更高效的推理。

高效的语言模型服务系统(LLM)需要批量处理多个请求。然而,现有系统存在以下问题:

  • 每个请求的key-value缓存(KV缓存)内存巨大,动态增长和减少。
  • 容易因为碎片化和冗余复制导致内存浪费,限制了批量大小。

为了解决这些问题,提出了PagedAttention,一个基于虚拟内存和分页技术的注意力算法。基于此,开发了vLLM,一个LLM服务系统,实现了以下两个目标:

  1. KV缓存显存的几乎零浪费,减少了显存碎片。
  2. KV缓存在请求之间和请求内共享,进一步减少显存使用。

论文包含了他早期设计。

一次调用的示例如博客中展示的。

AsyncLLM

generate细节:

  • 如果引擎没有运行,启动后台循环,循环调用 _run_output_handler 方法来处理等待的请求。
  • AsyncStream 中等待请求输出并生成它们。

engine会在启动之前profile一下,把剩余的显存分配给kv cache用。

AsyncStream 对 asyncio.Queue的封装,支持了终止的能力,当finish的时候会丢入一个STOP_ITERATION的exception,这样可以让调用者知道这个stream已经结束了。

每当有一个对话请求的时候调用add_request就会生成一个这样的AsycStream用于处理对话的输出,其中副作用就是判断backgroud loop没有启动的时候,启动backgroundloop。

AsyncEngine本身有一个_new_request的Queue用户保存request的AsyncStream。

generate方法会不断从AsyncStream中yield出结果,直到遇到STOP_ITERATION。

loop的主体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 1) Pull EngineCoreOutput from the EngineCore.
outputs = await self.engine_core.get_output_async()

# 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)

# 3) Put the RequestOutputs into the per-request AsyncStreams.
self._process_request_outputs(request_outputs)

# 4) Abort any requests that finished due to stop strings.
await self.engine_core.abort_requests_async(reqs_to_abort)

# 5) Abort any requests due to client cancellations.
await self._process_cancellations()

When TP=n & PP=m, vLLM engine will have n*m + 1 processes in total.
Corollary: even when using a single GPU, we will have 2 processes.

EngineCore

EngineCore主要是完成 schedule、execute 和 output 的循环。

1
2
3
4
5
6
7
8
9
10
11
def step(self) -> List[EngineCoreOutput]:
"""Schedule, execute, and make output."""

if not self.scheduler.has_unfinished_requests():
return []

scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output)
return engine_core_outputs

Request

在具体分析之前,先看看 Request 的定义,这个数据结构串联了很多东西。

属性 num_tokens 代表的是 prompt_tokensoutput_tokens 的总数。

1
2
3
@property
def num_tokens(self) -> int:
return len(self._all_token_ids)

num_output_tokens 代表 output tokens 的数量。

1
2
3
@property
def num_output_tokens(self) -> int:
return len(self._output_token_ids)

append_output_token_ids 会改变上述的两个属性。

1
2
3
4
5
6
7
8
def append_output_token_ids(
self,
token_ids: Union[int, List[int]],
) -> None:
if isinstance(token_ids, int):
token_ids = [token_ids]
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)

__init__ 方法中,会设置 num_prompt_tokens,这个是不变的,num_computed_tokens 会初始化为 0。

1
2
3
4
self.prompt = self.inputs.prompt
self.prompt_token_ids = self.inputs.prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
self.num_computed_tokens = 0

所以在最开始时,num_tokensnum_prompt_tokens 是相等的。当 prefill 以后,num_computed_tokens 会逐渐(逐渐的原因是 prefill 可能会被 chunked 掉)等于 num_prompt_tokensdecode 以后,num_tokens 会等于 num_prompt_tokens 加上 num_output_tokens。如果 computed_tokens 等于 num_tokens,说明已经开始 decode 了,要开始一个 token 一个 token 计算了。

在调度过程中没有直接用 computed_tokens 等于 num_prompt_tokens 的原因是:如果一个 request 被抢占掉,那么 num_tokens 在 request 恢复的时候其实应该是 num_prompt_tokens 加上 num_output_tokens,这里做了一个统一的判断。如果把preempted的request重新处理的话其实相当于多了一些output tokens的prompt的新request。

Scheduler

从 EngineCore 的 step 方法来看,目前的调度是同步的 schedule | execute model | update_from_output | schedule | execute model | update_from_output,这样会导致计算和调度之间的时间差,这个时间差会导致计算的时间没有充分利用,从而导致资源的浪费。后面的版本应该会有优化。

Scheduler 的 V1 版本把一些 chunked prefill 还有 prefix caching 的内容拆离出去,做得比较通用。

vLLM 实现了一种所有或无(all-or-nothing)驱逐策略,即要么驱逐一个序列中的所有块,要么不驱逐任何块。

接下来看看来自 v1/core/scheduler.py 的 V1 版本的 schedule 实现。

Scheduler 有个 waiting list 和 running list(位置代表权重,是 FIFO 的)。

从 running list 中获取 request 然后通过 kv_cache_manager 执行 append_slots 把新的 block 追加到 request 的 block chain 当中。如果最后一个 block 的 slot 还够的话,就不会追加新的 block。

1
2
new_blocks = self.kv_cache_manager.append_slots(
request, num_new_tokens)

如果当前的 kv_cache 的 block table 满了,则会抢占一个 running list 中的 request(放入 waiting list 中)并且把他的 cache block 都 free 掉,这里的 free 是引用计数的形式,如果引用计数为 0 就会被释放,但如果多个 request 共享了一个 block 就还不会被真正释放。

1
2
3
4
5
6
7
8
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
self.waiting.appendleft(preempted_req)
preempted_reqs.append(preempted_req)

加入到 scheduled_running_reqs 中,消耗这次调度的 token budget,这个 budget 用完以后就会停止调度了。

1
2
3
4
5
6
7
scheduled_running_reqs.append(request)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1

如果没有抢占请求则说明还是比较富裕的,尝试从 waiting list 中获取 request,waiting list 可能有新请求也可能有之前被抢占的请求,然后执行一遍上面的代码,不同的是需要从 kv_cache_manager 计算 computed_tokens,因为被之前被抢占的或者一些有共同前缀的 kv cache block 是已经缓存过的。

1
2
3
4
5
6
7
8
request = self.waiting[0]
# Get already-cached tokens.
computed_blocks = self.kv_cache_manager.get_computed_blocks(
request)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size

最后把每个 request 分配到的 tokens 数量记录到 SchedulerOutput 当中。

update_from_output 接受 SchedulerOutputModelExecutorOutput,更新 request 的状态,例如更新已经计算的 token 数量,更新 kv cache 的 block 等。对于每个请求都会检查 request.num_computed_tokens == request.num_tokens 从而判断是否已经开始 decode 的部分了。然后构造 EngineCoreOutput,并且检查是否需要停止这个 request。_check_stop 方法会检查是否已经生成了 eos token 或者已经达到了最大长度,并且 free 掉对应的 request。所有没有 stop 的 request 会重新加入到 running 队列中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
if request.num_computed_tokens == request.num_tokens:
req_index = model_runner_output.req_id_to_index[req_id]
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.append_output_token_ids(token_id)
num_new_tokens = 1
# TODO: Update the KV cache manager for prefix caching.

# Check for stop and update request state.
# This must be called before me make the EngineCoreOutput.
stopped = self._check_stop(request)

# Add EngineCoreOutput for this Request.
output = EngineCoreOutput(
request_id=req_id,
new_token_ids=request.output_token_ids[-num_new_tokens:],
finished=request.is_finished(),
finish_reason=request.get_finished_reason(),
stop_reason=request.stop_reason)
engine_core_outputs.append(output)

# Breakout of the loop.
if stopped:
continue

SchedulerOutput

该类包含了调度器的输出信息。以下是各个字段的作用:

  • scheduled_new_reqs: List[NewRequestData]

    • 作用:存储新请求的数据列表,这些请求是刚刚被调度的。
  • scheduled_resumed_reqs: List[ResumedRequestData]

    • 作用:存储恢复请求的数据列表,这些请求是之前被暂停,现在重新被调度的。
  • scheduled_running_reqs: List[RunningRequestData]

    • 作用:存储正在运行请求的数据列表,这些请求在当前调度周期内继续运行。
  • num_scheduled_tokens: Dict[str, int]

    • 作用:存储每个请求调度的token数量,键是请求的ID,值是对应的token数量。
  • total_num_scheduled_tokens: int

    • 作用:存储所有请求调度的token总数。
  • scheduled_encoder_inputs: Dict[str, List[int]]

    • 作用:存储每个请求的编码器输入,键是请求的ID,值是对应的编码器输入列表。
  • preempted_req_ids: Set[str]

    • 作用:存储被抢占的请求ID集合,这些请求在当前调度周期内被暂停。
  • finished_req_ids: Set[str]

    • 作用:存储已完成的请求ID集合,这些请求在当前调度周期内完成。
  • free_encoder_input_ids: List[Tuple[str, int]]

    • 作用:存储空闲的编码器输入ID列表,每个元素是一个元组,包含请求ID和对应的编码器输入ID。

这些字段共同描述了调度器在一个调度周期内的所有操作和状态变化。

KVCacheManager

来自v1/core/kv_cache_manager.py,这是v1版本的实现。

kv cache比较简单,
这个博客中的图片很好地阐述了kvcache的作用。

但涉及PagedAttention的实现,就需要管理block。这类似于操作系统中的虚拟地址、页表和物理页的关系。

PagedAttention的主要思想是基于操作系统中分页(paging)的经典概念。传统的注意力算法通常要求keys和values在内存空间中连续存储,
而PagedAttention则允许在非连续的内存空间中存储keys和values。

PagedAttention将每个序列(sequence)的KV缓存(KV cache)分成固定大小的块(block)。
每个块包含一个固定数量的token的key和value向量。这意味着,即使keys和values不连续存储,也可以有效地访问和操作它们。

block的管理会有一个类似于页表的结构,用于映射block的逻辑地址到物理地址。

论文中的这个图很好的表示了他们的关系,如果新生成的token填满了当前block就会分配一个新的block用于新token的生成。

共享的prefix cache指的是提示词的前缀一样的情况,他们的位置编码也不变的情况下可以在不同的sequence之间共享。
例如对于一个英语到法语翻译的提示词,前面有很多事可以共享的,对于跨请求的kv cache来说可以基于这个前缀来共享kv cache的block。

序列 前缀 (Prefix) 输入任务 (Task Input) 完整提示 (Complete Prompt) LLM 输出 (LLM Output) 输出任务 (Task Output)
Sequence A Translate English to French:
“sea otter” => “loutre de mer”
“peppermint” => “menthe poivrée”
“plush giraffe” => “girafe en peluche”
“cheese” => Translate English to French:
“sea otter” => “loutre de mer”
“peppermint” => “menthe poivrée”
“plush giraffe” => “girafe en peluche”
“cheese” =>
fromage fromage
Sequence B Translate English to French:
“sea otter” => “loutre de mer”
“peppermint” => “menthe poivrée”
“plush giraffe” => “girafe en peluche”
“I love you” => Translate English to French:
“sea otter” => “loutre de mer”
“peppermint” => “menthe poivrée”
“plush giraffe” => “girafe en peluche”
“I love you” =>
Je t’aime Je t’aime

free_block_queue是一个链表,用于分配block,初始化时将block_pool中的所有blocks串起来。它通过链表实现对KVCacheBlock的管理,删除操作是O(1)的,没有使用标准库中的dequeue。

KVCacheBlock除了prev和next指针,还有ref_count_block_hash,用于prefix caching的计算。其key是父block的hash和当前block的tokens ids的hash。

block_pool代表物理block的映射关系,例如0 -> 第一块block

1
2
3
4
# A Block pool of all kv-cache blocks.
self.block_pool: List[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]

cached_block_hash_to_block 保存的数据结构是 {block_hash: {block ID: block}}

req_to_blocks 保存了 request到 block列表的映射关系,{request ID: [block ID]}

block的eviction的定义,eviction candidate == in free queue and ref_cnt == 0

get_computed_blocks方法

根据request获取已经计算过(缓存过)的block,获取kv cache blocks的方式是通过block hash 从cached_block_hash_to_block寻找的。
hash的计算是之前的block hash加上当前token ids做一次hash,第一个block则没有父block只用当前自己的token ids做hash。

append_slots方法

会为需要新计算的token ids分配block(如果现有的block不够的话)。

Worker

GPUModelRunner

v1/worker中的gpu_runner.pyv1版本的实现。

首先依赖一个大的config参数vllm_config,包含了model_configcache_configscheduler_configdevice_config等。

初始化kv cache的dtype,对照表如下,half就是fp16,float就是fp32,默认是和模型的dtype一样。

1
2
3
4
5
6
7
8
STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
}

初始化sliding_window的配置,这个东西在Qwen里面才用到。

初始化block_size,决定了kv cache中连续保存的token的数量,也就是PagedAttention中的那个block的大小,Prefix cache也是以block为维度的。

初始化kv_heads,这个决定了kv head的数量,如果指定了 tensor_parallel_size,会根据这个参数平均分给每个GPU。

初始化head_size,基于model config,是model config里面的head_dim

初始化hidden_size,就是model config里面的hidden_size,就是d_model或者embed_dim,代表同一个长度。

初始化kv_cache

初始化encoder_cache encoder结果的缓存。

初始化input_registry 和多模态的支持有关系。

初始化requests dict用于request的状态保存,这里的request就是一个文本的sequence。

初始化InputBatchmax_num_seq决定了batch的宽度,max_model_len决定了batch的长度。这个Batch对象负责管理在用于前向传播的batch当中的request的插入和删除。

初始化use_cuda_graph 这个由 enforce_eager 决定,默认是会加载整个计算图。

初始化positions: torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device)

初始化input_embeds,可以看到,宽度是max_num_tokens,长度是hidden_size,这个是用来存储输入的embedding的。

1
2
3
4
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)

InputBatch

InputBatch在整个工程中负责管理和处理批量输入请求,确保请求的高效处理和管理。

execute_model 方法

execute_model是整个schedule | compute | update循环中的核心部分,负责执行模型的前向传播。

_update_states方法

在每次运行一个 batch 时,会根据调度器(scheduler)的要求调整每个 batch 中请求的优先级。调度器会更新请求的状态缓存 id -> CachedRequestStateinput_batch 的缓存,移除被抢占和停止的请求,并将新加入的请求放入 batch 中。因此,runner 只负责执行,具体的策略由调度器决定。

CachedRequestState 记录了请求 ID、使用的缓存块 ID 以及已计算的 token 数量。

_excute_encoder方法

执行多模态中的encoder,对于新的多模态的encode,调用model.process_mm_inputs存入到encoder_cache当中。

self.model.compute_logits使用
vllm/model_executor/layers/logits_processor.py中的LogitsProcessor,从hidden_states计算logits

self.model.sample使用vllm/model_executor/layers/sampler.py中的sampler进行sample。

最终得到sampled_token_ids = sampler_output.sampled_token_ids

_gather_encoder_outputs 方法

encoder_cache中获取当前batch需要用到的encoder的输出。

_prepare_inputs方法

input_batch.block_table 在 GPU 上,而 input_batch.block_table_cpu_tensor 在 CPU 上。
前面提到 batch 的整理是在 CPU 上进行的,这里是将要推理的部分拷贝到 GPU 上的 block_table 中。由于使用了 PagedAttention,因此所有的序列都是按 block 为粒度进行切分的。

获取input_ids,构造出传给FlashAttention的数据,例如block_table,和query_start_locseq_start_loc用于定位query和seq的位置。

input_ids, attn_metadata, logits_indices

_prepare_sampling方法

构造出sampling的参数,获取每个request的temperaturetop_ktop_p等参数。

GPUWorker

v1/worker中的gpu_worker.pyv1版本的实现。
初始化GPUModelRunner,如果开始了VLLM_TORCH_PROFILER_DIR就会调用torch.profiler.profile

determine_num_available_blocks会通过profile的方式决定可以使用的block数量。
然后根据block数量调用Runner的initialize_kv_cache

做一些GPU的dtype支持检查,比如一些老的GPU是不支持bf16的。

FlashAttentionMetadata 包含了input的结构和对应的block table的映射。

FlashAttention是一种新型的注意力算法,它能够准确计算注意力,且只需进行远远少于传统方法的内存访问。这个算法的主要目标是尽可能避免内存的读取和写入,这是注意力机制性能瓶颈的一个关键因素。该论文提出了一个IO-aware的精确注意力算法,它使用tiling(贴瓷砖,代表数据分片)来减少GPU高带宽内存与低带宽内存之间的内存读取/写入次数。

该算法基于注意力矩阵通常稀疏这一观察结果:注意力矩阵只有少数元素非零。它通过将输入矩阵Q、K、V分成更小的块来实现,从而避免了计算全矩阵乘积Q*K^T的内存占用问题。通过块级别的处理,FlashAttention使得矩阵操作可以在现代GPU的内存限制下进行,并仅读取/写入每个切片的非零元素。这降低了需要的内存访问次数,使整个过程更快和更高效。

FlashAttention通过分“瓦片化”的方式计算能够更快的一个原因是将矩阵放入更高速的缓存当中,高速的叫SRAM,低速的叫HBM。

第一代 FlashAttention 只是把QK切片,这个只要把矩阵切分在SRAM,然后计算出结果再存回HBM,这个比较简单。

第二代 FlashAttention 把 softmax 的计算也放在了SRAM上。

源自博客描述的结构中可以看出。

他这里面标得感觉不是很清楚,其中的O_2应该是最终的结果O,里面的l_1/l_2 * A^1 / l_1 就还原出了最终结果的分母,也就是scale法则。

第三代 FlashAttention 减少了上面提到的scale,不再每一步做除法,而是放到最后再除。还有就是针对交叉注意力中的mask的优化,跳过了被mask的部分。还有就是CUDA Thread warps的优化提高了并行度。

总结

FlashAttention通过利用高速缓存和分块技术,显著减少了内存访问次数,提高了注意力计算的效率。第一代主要通过切分QK矩阵并利用SRAM缓存,第二代将softmax计算也放入SRAM,第三代则进一步优化了scale计算和mask处理,并提升了并行度。

vLLM 基于 uvicorn + FastAPI 的异步 Web 框架构成。vLLM 的主体是 LLMEngine,它是一个单例类,负责管理所有的模型和数据。在异步 API 中使用的是一个 AsyncEngine。在分析 AsyncEngine 之前,我们先将 Web 部分单独拆出来看一下。

vLLM 的 CLI 入口是 vllm/scripts.py,其中 serve 的启动是通过 uvloop.run 的方式启动的。uvloop 是一个替代默认 asyncio 事件循环的库,它使用 libuv 作为事件循环的实现,从而提高性能。uvicorn 是一个基于 uvloop 的 ASGI 服务器,它可以将 ASGI 应用部署到 Web 服务器上。FastAPI 是一个基于 Starlette 的 Web 框架,它提供了许多便利的功能,比如自动文档生成、请求参数校验等。

参数经过解析以后会进入 run_server,通过 uvloop.run(run_server(args))run_serverentrypoints/openai/api_server.py 下面。AsyncEngineArgs.from_cli_args(args) 使用命令行参数初始化 AsyncEngineArgs,如果要自行封装的话可以直接初始化 AsyncEngineArgsAsyncEngineArgs 继承自 EngineArgs,其中的参数都是用来控制推断命令的。

比较常用的几个参数:

  • model: 模型的路径,可以是一个目录,也可以是 hf 上的一个 repo。

  • model_name: 如果是目录的话,期望的模型名称,或者想要改个别名,对应的是 API 中指定模型的名称。

  • tensor_parallel_size: tensor parallel 副本数,如果用多个 GPU 可以用到,会根据这个将 kv head 平分到不同的 GPU 上。

  • pipeline_parallel_size: pipeline stages 数,如果用多个 GPU 可以用到,会根据这个将模型的前向计算的layers分成多个阶段,每个阶段在不同的 GPU 上计算。

    可以参考下面这个例子:

    假设我们有 8 个 GPU,分别表示为 g0 … g7,并且我们使用 2 个 GPU 来并行化模型张量,使用 4 个 GPU 来并行化模型流水线。当前函数将创建 4 个张量模型并行组和 2 个流水线模型并行组:

    4 个张量模型并行组:

    • [g0, g1]
    • [g2, g3]
    • [g4, g5]
    • [g6, g7]

    2 个流水线模型并行组:

    • [g0, g2, g4, g6]
    • [g1, g3, g5, g7]

    注意,为了提高效率,调用者应确保相邻的 rank 位于同一个 DGX 盒子上。例如,如果我们使用 2 个 DGX-1 盒子,总共有 16 个 GPU,rank 0 到 7 属于第一个盒子,rank 8 到 15 属于第二个盒子。

  • num_seqs: 最大的序列数,其实就是 batch size,会翻倍得增加显存使用,这个貌似在启动之前的 profile 阶段可能会导致大量显存的占用。

  • quantization: 量化的方法,可以是 bitsandbytes 等,可能需要和 load_format 结合使用。

  • load_format: 加载模型的格式,可以是 pt, safetensors, bitsandbytes 等等,如果用到量化的模型基本要改成 bitsandbytes。

  • dtype: 数据类型,fp32, fp16,bf16 等等,如果模型是 bf16 的话,他默认是 bf16 的模型用 bf16,有些显卡不支持 bf 浮点数所以要设置成 half 也就是 fp16。

  • host: 监听地址。

  • port: 监听端口。

  • max_model_len: 上下文长度,适合显存不足的显卡,把默认的上下文长度改下一点。

  • enforce_eager: 是否强制使用 eager 模式,如果显存不够的需要开启这个模式,不完全加载计算图的方式可以减少显存的使用。

api_server 中的 build_app 会使用 APIRouter 初始化路由,并通过 app.include_router 引入。

主要看 @router.post("/v1/chat/completions") 注册的 async def create_chat_completion 是最常用的函数调用。

init_app_stateapp.state 中保存了 openai_serving_chat,以及其他一些接口的状态,这取决于模型配置中是否包含这些功能。例如,文本嵌入等功能(通常都有)。当调用 create_chat_completion 时,会调用 openai_serving_chat 对应的 OpenAIServingChat 类的方法。因此,Serving 的主体可以通过查看这个对象的方法来理解其功能。

构建 AsyncEngine -> 构建 app 对象。

OpenAIServingChat.create_chat_completion 主体流程

  1. 检查模型

    • 是否支持 model,model 是否是 rola,model 是否是 prompt adapter 等。

      vLLM 的 rola 不是和基座合并在一起的,是支持基座模型加多了个 lora 模型的形式。prompt adapter 看起是多模态架构中的 adaptor。

  2. 从 Engine 中获取 Tokenizer

    • 主要是基于 model path 获取对应的 tokenizer 文件,并初始化对应的 tokenizer。
  3. _preprocess_call:对输入进行预处理

    • resolve_chat_template_content_format:检查对话模板格式,因为每种大模型的用于生成文本的训练数据的格式有所不同,要确认对应的格式,LLAMA 有 LLAMA 的格式,可以参考下面的例子。
    • parse_chat_messages_futures:解析输入的聊天消息,生成一个对话消息列表,变成有类型的对话消息。其中 mm_tracker 要处理 image_urlaudio_url 的消息,会根据构造 placeholderplaceholder 是一个特殊的字符串,用来标记这个位置是一个占位符。llama3.2 用的是 <|image|>
    • apply_{hf,mistral}_chat_template:模板会给提示词添加提示词的开头和结束的标志,从而和实际训练的数据标注对齐,比如 llama3<|eot_id|> 标记结束,padding 等。request_promptengine_prompt 包含 token ids 和多模态数据。
      例如:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    chat = [
    {
    "role": "user",
    "content": [
    {"type": "image"},
    {"type": "text", "text": "If I had to write a haiku for this one, it would be: "}
    ]
    }
    ]

    会变成 <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|>If I had to write a haiku for this one, it would be: <|eot_id|> 中,<|start_header_id|>user<|end_header_id|> 标识 header(也就是 role),<|begin_of_text|> 标识上下文的开头,<|eot_id|> 标识一个消息的结束。除此之外,对于function call的处理,可以参考 examples/tool_chat_template_llama3.2_json.jinja 的一部分可以看出,会把对应工具的调用和提示词加入到用户对话前面,作为 user 的 text 的前缀中的内容形成提示词的一部分上下文。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    {#- Custom tools are passed in a user message with some extra guidance #}
    {%- if tools_in_user_message and not tools is none %}
    {#- Extract the first user message so we can plug it in here #}
    {%- if messages | length != 0 %}
    {%- if messages[0]['content'] is string %}
    {%- set first_user_message = messages[0]['content']|trim %}
    {%- else %}
    {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %}
    {%- endif %}
    {%- set messages = messages[1:] %}
    {%- else %}
    {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
    {%- endif %}
    {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
    {{- "Given the following functions, please respond with a JSON for a function call " }}
    {{- "with its proper arguments that best answers the given prompt.\n\n" }}
    {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }}
    {{- "Do not use variables.\n\n" }}
    {%- for t in tools %}
    {{- t | tojson(indent=4) }}
    {{- "\n\n" }}
    {%- endfor %}
    {{- first_user_message + "<|eot_id|>"}}
    {%- endif %}
    • 请求处理:生成请求的 id request_id = f"chatcmpl-{request.request_id}",确定采样方法 beam_search 还是 sampling,调用 AsyncEngine 的 beam_searchgenerate 方法获得一个 generator。

    • chat_completion_stream_generator 是基于 generator 处理响应,这里主要看 streaming 的部分,同步的请求会直接返回结果。流式响应的格式是多个基于 json 格式的 chunk,类型是 chat.completion.chunk

    1
    {"id": "chatcmpl-1eadb733adf64f5b90114307b2d4d718", "choices": [{"delta": {"content": "", "function_call": null, "refusal": null, "role": "assistant", "tool_calls": null}, "finish_reason": null, "index": 0, "logprobs": null}], "created": 1732869116, "model": "llama3.2", "object": "chat.completion.chunk", "service_tier": null, "system_fingerprint": null, "usage": null}
    1
    {"id": "chatcmpl-1eadb733adf64f5b90114307b2d4d718", "choices": [{"delta": {"content": "AI", "function_call": null, "refusal": null, "role": null, "tool_calls": null}, "finish_reason": null, "index": 0, "logprobs": null}], "created": 1732869116, "model": "llama3.2", "object": "chat.completion.chunk", "service_tier": null, "system_fingerprint": null, "usage": null}
    1
    {"id": "chatcmpl-1eadb733adf64f5b90114307b2d4d718", "choices": [{"delta": {"content": " assistant", "function_call": null, "refusal": null, "role": null, "tool_calls": null}, "finish_reason": null, "index": 0, "logprobs": null}], "created": 1732869116, "model": "llama3.2", "object": "chat.completion.chunk", "service_tier": null, "system_fingerprint": null, "usage": null}
    1
    {"id": "chatcmpl-1eadb733adf64f5b90114307b2d4d718", "choices": [{"delta": {"content": "", "function_call": null, "refusal": null, "role": null, "tool_calls": null}, "finish_reason": "stop", "index": 0, "logprobs": null}], "created": 1732869116, "model": "llama3.2", "object": "chat.completion.chunk", "service_tier": null, "system_fingerprint": null, "usage": null}

    AsyncEngine Client 的 generate 会返回一个异步生成器,result_generator,通过 async for 遍历这个生成器 result,而 result 又是一个 output 的生成器。num_cached_tokens 表示前缀匹配的 kv cache 命中的 token 数量。request.n 代表要生成的选择的数量,一般是 1,如果大于 1 就会生成多个选择的分支,而 response 中的 index 就会代表不同的分支的序号。result 生成器对应的就是多个分支的结果,而 result 中的 output 就代表一个分支中的 chunk。处理过程中会把 output 转化成 ChatCompletionStreamResponse,输出成 data: $json_dump 的 SSE chunk 的形式。stream_options.include_usage 如果设置了的话会在 DONE 之前返回一个 usage stats 的 chunk。

    • tool_parser:解析工具描述。方法和对应的类在 openai/tool_parsers 下面,会根据传入的初始化参数决定对应的解析类。如果对应的 request 有 tool_choice 参数,就会使用到 tool_parser,tool_parser 主要用于处理响应中的 tool call 的文本内容。tool_parsertool_choice 为 auto 的时候要调用对应的 extract_tool_calls_streaming 去解析函数调用的文本内容。例如 pythonic_tool_parser 会解释 [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] 这种类似 Python 的文本内容并转化为响应中的 ToolCall 对象。如果是 llama3.1 的 template 的话,参考上面的格式,会把输出 {"name": function name, "parameters": dictionary of argument name and its value} 转化为 ToolCall 对象。

总结

vLLM 的主体是 LLMEngine,它是一个单例类,负责管理所有的模型和数据。在基于FastAPI的异步Restful API 中使用的是一个 AsyncEngine。在交给Engine处理之前会对一些请求参数进行预处理,比如对话模板的格式化,对话消息的解析,模板中的函数调用等。

这篇文章对VLM的架构解释得非常清楚。

一种方法是使用适配器将图片转换为tokens,例如LLaVA使用的prompt based适配器。这种方法类似于RAG的形式,将图片理解的内容补充在对话的上文中。这种适配器会占用LLM的上下文长度,因为图片的tokens会被放入LLM的上文中。目前来说性能会好一些。

另一种方法是基于交叉注意力的适配器,这种方法不会占用LLM的上下文长度,但需要大量参数来达到良好的质量。Llama3.2就是这种结构。

关于Llama3.2本身,它使用了GQA,将kv head分组,多头查询将原本的K和V头分成组并为每个组生成一个共享的Head,这样可以减少kv cache而不太丧失精度(相较于MQA这种只共享一个KV头的方法)。因此,分组多头查询在多头查询注意力和正常多头注意力之间维持了平衡,既考虑了速度,又考虑了输出质量。另一个优化是对一个上下文中的不同文档进行mask处理。由于大模型的上下文现在很长,会将多个文档放入一个上下文中进行训练,但为了避免文档之间的相互影响,需要在文档级别进行mask处理,即当前token不能看到之后的token,也不能看到同一上下文中其他文档的token。其他改动主要是训练规模的调整。

根据Llama3.2的技术报告,里面的image encoder用的是ViT架构。适配器在语言模型和图像编码器之间引入交叉注意力层(cross-attention layers),以提高模型的效率和准确性。交叉注意力层使用通用查询注意力(GQA)并在核心语言模型每四层之后应用。交叉注意力层增加了大量可训练参数,例如Llama 3 405B中约有100B个参数。

本质上,图片编码器的输出通过适配器后作为交叉注意力层的K,文本作为Q,V也来自图片适配器,从而计算文字和图片之间的注意力关系,然后与LLM的输出进行交叉注意力。在训练Llama3.2的适配器时,同时更新了图像编码器的参数,但刻意不更新语言模型的参数。这意味着在适配器训练过程中,Meta只关注图像编码器和适配器的学习,而不影响语言模型的预训练知识。
简而言之,这个适配器在功能上类似于最初的encoder-decoder Transformer中的encoder部分。

在具体的以vLLM推断过程的实现为例,对话的API中会包含{"type":"image","image_url":"uri_of_the_image"},在应用对话模板以后会插入占位符,比如llama3.2用的就是<|image|>,原始的训练中的文本内容会变成类似"<|image|>If I had to write a haiku for this one",以此标记图片的位置信息,实际上需要图片会通过uri_of_the_image被加载到encoder中并携带<|image|>所代表的位置信息编码。

总的来说,VLM的计算过程和推断中的处理方式通过引入适配器和交叉注意力层,实现了图片和文本的高效融合,为多模态任务提供了强大的支持。

前言

在现代数据驱动的世界中,对高效且可扩展的数据存储方案的需求前所未有地强烈。键值数据库因其简洁性和卓越性能而备受推崇。对于那些热爱Rust编程并希望构建自己的键值数据库的开发者来说,这本书将是您的理想起点。笔者将引导您一步一步,从零开始,使用Rust设计和实现一个键值数据库。

数据库领域既神秘又充满魅力。正如俗话所说,不亲手实践,就无法深刻理解。通过构建一个键值数据库,我们不仅能深入掌握这类数据库的设计哲学和实现细节,还能借此机会深化对Rust语言的理解和应用。

本书内容围绕一个基于LSM(Log-Structured Merge-tree)的键值数据库设计和实现展开,参考了LevelDB、RocksDB、PebbleDB、AgateDB和BadgerDB等多个成熟数据库的实现。全书示例代码均使用Rust编写,旨在通过实战演练加深读者对Rust特性的理解,同时探索数据库技术的精髓。

LevelDB相较于其他衍生品,虽然没有一些新的论文和工程实践带来的优化,但保留了最初的设计,其相对简单和完整。其他数据库或多或少都能找到LevelDB的影子。

与《Rust编程语言》一书不同,本书不旨在覆盖Rust语言的所有知识点。我们的目标是实现一个键值数据库,因此读者需要具备一定的系统编程基础,例如文件系统相关的读写调用和指针的使用(不仅在unsafe的情况下,有指针的基础也方便理解引用等概念)。这些知识点如果完全讲解清楚会占据很大的篇幅。本书将解释使用到的Rust语法和特性。即使您没有阅读过《Rust编程语言》,也能跟随本书学习Rust的语法。如果您在阅读本书的Rust语法部分遇到困难,可以参考《Rust编程语言》中的相关章节。笔者也是将《Rust编程语言》作为一本参考书反复阅读,并不需要一次性完全读完,书都是常看常新的。

Rust的设计初衷是确保内存安全,避免常见的程序错误,如空指针解引用。其显著特点包括独特的所有权系统、零成本抽象、可靠的错误处理机制以及完善的工具链,使其在系统编程领域尤为突出。

尽管LevelDB是用C++编写的,Rust在某些方面被视为C++的现代替代品。对于那些对C++有深入了解的开发者而言,转向学习Rust应该会相对轻松。然而,Rust提供了与C++不同的编程范式。例如,迭代器和闭包是Rust标准库和语法的一部分,作为语言的核心部分,引入了多种语法糖来支持这些功能,这可能会让熟悉Python的开发者感到亲切。

经典的内存错误包括使用已释放内存的指针、向量长度被修改但另一个引用仍保留原长度信息导致访问不确定内存地址等。Rust的所有权系统在编译阶段就能避免这些问题。

Rust借鉴了函数式编程的多种技巧,为开发者提供了一种既熟悉又新颖的编程体验,例如模式匹配、迭代器、闭包、泛型等。这些特性使得Rust在编写高效、安全和易维护的代码方面具有独特优势。

与有GC的语言相比,Rust可能让使用者感到不适应,许多在其他语言中理所当然的写法在Rust中行不通。Rust对指针的可变性有明确限制,并且只允许存在一个可变引用。这种所有权的检查使编译器变得非常严格,在一定程度上增加了编程的复杂性。然而,这也是Rust的优势之一,它能够在编译阶段发现许多潜在的错误。

相比暴露指针的语言,Rust对内存的解引用有严格的检查。尽管所有权系统有时会让代码显得冗长,但Rust提供了许多有趣的语法和特性,如模式匹配、错误处理宏、默认返回末端表达式等,使得编写Rust代码既轻便又高效。

Rust的性能非常出色,部分Linux内核驱动和Windows安全模块已经采用Rust实现。AWS也在许多地方使用Rust,飞书客户端的一部分代码也使用了Rust,这在一定程度上证明了Rust在系统编程领域的优异表现。如果需要选择一种新语言开发消息队列、数据库、文件系统等软件,Rust是一个非常不错的选择。这也是本书使用Rust实现的原因之一,以展示Rust在这些方面的优势。

在错误处理方面,Rust采用?问号宏简化了传统的错误处理流程,相比Go语言中显式处理错误的if err != nil {}模式是一种进步。这种简洁的错误处理方式不仅提高了代码的可读性,也加速了开发过程。

这些特色贯穿本书始终,在随后的章节中,您将探索到Rust的更多有趣特性。

笔者是一名Rust初学者,里面的很多实现可能存在不正确的写法,欢迎指正。

本书面向的读者

  • 希望通过具体项目深入学习Rust,特别是在键值数据库方面的开发者。
  • 对键值数据库的设计和实现感兴趣的初学者,希望通过实践学习相关内容。
  • 想了解LevelDB架构和实现细节的读者,可以选择性地跳过实现部分进行阅读。

如何阅读和使用代码

本书的第一章主要是讲基础概念例如一些基础的数据结构。第二章开始分部分讲解实现的细节。对于
对Rust比较熟悉的读者可以跳过第一章。

第一章Rust

完整讲解Rust的内容将要消耗大量篇幅,也不是本书的目的。本章节主要介绍在实现过程中会会涉及的一些的语法和特性,让读者在阅读代码的时候没有过多障碍。

数据类型

Rust 是一种静态类型的编程语言,其数据类型可以分为两大类:原始类型(Primitive Types)和复合类型(Compound Types)。以下是 Rust 中常见的数据类型:

原始类型(Primitive Types)

  • 整数类型(Integer Types):表示整数。有符号整数包括 i8i16i32i64i128,无符号整数包括 u8u16u32u64u128

    1
    2
    let signed_integer: i32 = -42;
    let unsigned_integer: u64 = 42;
  • 浮点数类型(Floating-Point Types):表示小数。Rust 有两个浮点数类型:f32f64

    1
    2
    let float32: f32 = 3.14;
    let float64: f64 = 3.14;
  • 布尔类型(Boolean Type):表示逻辑值,只有两个可能的值:truefalse

    1
    2
    let is_true: bool = true;
    let is_false: bool = false;
  • 字符类型(Character Type):表示单个字符。字符类型使用单引号 '

    1
    2
    let char_a: char = 'a';
    let char_heart: char = '❤';

复合类型(Compound Types)

  • 数组类型(Array Type):表示固定大小的数组。数组中的所有元素必须拥有相同的数据类型。

    1
    let array: [i32; 5] = [1, 2, 3, 4, 5];
  • 元组类型(Tuple Type):表示具有不同数据类型的有序集合。元组的长度是固定的。

    1
    let tuple: (i32, f64, char) = (42, 3.14, 'a');
  • 切片类型(Slice Type):表示对数组或其他集合的引用,但没有固定大小。切片是一种动态大小的视图。

    1
    2
    let array: [i32; 5] = [1, 2, 3, 4, 5];
    let slice: &[i32] = &array[1..4];
  • 字符串类型(String Type):表示动态可变的文本字符串。它由 String 类型表示。

    1
    let my_string: String = String::from("Hello, Rust!");
  • 引用类型(Reference Type):表示对值的引用。引用在 Rust 中被广泛用于实现借用和所有权系统。

    1
    2
    let original_value: i32 = 42;
    let reference: &i32 = &original_value;

这些数据类型提供了灵活性和安全性,通过所有权、借用和生命周期等概念,Rust 的类型系统确保了内存安全和线程安全。在编写 Rust 代码时,正确使用这些数据类型有助于减少运行时错误并提高代码的可维护性。

基本语法

let用于声明变量。

1
let x = 1; // 声明x并赋值为1。

可以使用:显式指定变量类型:

1
let x: i32 = 1;

_表示“存在但不关心”的变量,用于有意忽略某些处理:

1
2
3
4
// 赋值给一个不需要使用的变量
let _ = 1;
// 忽略函数的返回值
let _ = get_thing();

_开头的变量表示暂时忽略以避免编译检查,适合在开发过程中使用:

1
let _x = 1;

let可以“覆盖”变量,使之前相同名称的变量失效,且变量类型可以不同:

1
2
3
let x = 1;
let x = 1 + 2;
let x = "str";

Rust也有元组,相当于固定长度的“容器”可以容纳不同的类型,元组可以指定类型。

1
2
3
let pair : (char, i32) = ('a', 17);
pair.0;
pair.1;

元组适用于解构,下面的代码中some_char'a'some_int是17。结构也可以使用_忽略全部或者其中一部分。解构也适用于函数的返回值。

1
2
3
4
let (some_char, some_int) = ('a', 17);
let (_, some_int) = ('a', 17);
let (_, _) = ('a', 17);
let (left, right) = slice.split_at(middle);

{}可以划分作用域,如果使用之前的覆盖规则,可以在内部作用域覆盖外部作用域的变量。

1
2
3
4
5
6
7
8
9
10
11
fn main() {
let x = "out";
{
// x = "in" 覆盖了外面的"out"
let x = "in";
// 这里会打印"in"
println!("{}", x);
}
// 这里会打印"out"
println!("{}", x);
}

在Rust中,语块也是表达式。

1
2
3
let x = 42;

let x = { 42 };

语块可以包含多个语句,最后一个不以分号结尾的语句是这个语块的值,否则默认等于()

1
2
3
4
5
let x = {
let y = 1;
let z = 2;
y + z
};

函数中也有类似的写法。

1
2
3
4
5
6
7
fn fair_dice_roll() -> i32 {
return 4;
}

fn fair_dice_roll() -> i32 {
4
}

if语句也是表达式。

1
2
3
4
5
6
7
fn fair_dice_roll() -> i32 {
if feeling_lucky {
6
} else {
4
}
}

match语句也是表达式。

1
2
3
4
5
6
fn fair_dice_roll() -> i32 {
match feeling_lucky {
true => 6,
false => 4,
}
}

结构体是用 struct 关键字声明的:

1
2
3
4
struct Vec2 {
x: f64, // 64位浮点数,即 "double precision"
y: f64,
}

它们可以使用结构体字面量初始化,顺序不重要,只有名称重要。:

1
2
let v1 = Vec2 { x: 1.0, y: 3.0 };
let v2 = Vec2 { y: 2.0, x: 4.0 };

还有一种用于从另一个结构体初始化剩余字段的快捷方式,这称为“结构体更新语法”,只能出现在最后位置,并且不能以逗号结束:

1
2
3
4
let v3 = Vec2 {
x: 14.0,
..v2
};

剩余字段也可以是所有字段,这样就可以复制整个结构体,而不是改变所有权:

1
let v4 = Vec2 { ..v3 };

结构体,像元组一样,可以被解构:

1
2
3
let v = Vec2 { x: 3.0, y: 6.0 };
let Vec2 { x, y } = v;
// `x` 现在是 3.0,`y` 现在是 6.0

下面这种形式可以通过..v.y会被忽略掉:

1
let Vec2 { x, .. } = v;

let模式可以用作if中的条件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
struct Number {
odd: bool,
value: i32,
}

fn main() {
let one = Number { odd: true, value: 1 };
let two = Number { odd: false, value: 2 };
print_number(one);
print_number(two);
}

fn print_number(n: Number) {
if let Number { odd: true, value } = n {
println!("Odd number: {}", value);
} else if let Number { odd: false, value } = n {
println!("Even number: {}", value);
}
}

match 也是一种模式匹配,就像 if let 一样:

1
2
3
4
5
6
fn print_number(n: Number) {
match n {
Number { odd: true, value } => println!("Odd number: {}", value),
Number { odd: false, value } => println!("Even number: {}", value),
}
}

match 必须是穷尽的:至少有一个分支需要匹配。

1
2
3
4
5
6
7
8
fn print_number(n: Number) {
match n {
Number { value: 1, .. } => println!("One"),
Number { value: 2, .. } => println!("Two"),
Number { value, .. } => println!("{}", value),
// 如果最后一个分支不存在,我们会得到一个编译时错误
}
}

如果穷尽匹配很难满足,可以使用 _ 作为 “通配符” 模式:

1
2
3
4
5
6
7
fn print_number(n: Number) {
match n.value {
1 => println!("One"),
2 => println!("Two"),
_ => println!("{}", n.value),
}
}

可以给类型声明方法。

1
2
3
4
5
6
7
8
9
10
struct Number {
odd: bool,
value: i32,
}

impl Number {
fn is_strictly_positive(self) -> bool {
self.value > 0
}
}

变量默认不可以改变。

1
2
3
4
5
6
7
8
fn main() {
let n = Number {
odd: true,
value: 17,
};
n.odd = false; // 错误:不能对 `n.odd` 赋值,
// 因为 `n` 没有被声明为可变的
}

不能被重新赋值

1
2
3
4
5
6
7
8
9
10
fn main() {
let n = Number {
odd: true,
value: 17,
};
n = Number {
odd: false,
value: 22,
}; // 错误:不能对不可变变量 `n` 重新赋值
}

可以使用mut关键字来声明可变变量。

1
2
3
4
5
6
7
fn main() {
let mut n = Number {
odd: true,
value: 17,
};
n.odd = false; // 没问题:`n` 是可变的
}

特征(Traits)在Rust中实现了类似其他语言中的多态功能:

特征定义了一组可以由多种类型共享的行为契约。其定义如下:

1
2
3
trait Signed {
fn is_strictly_negative(self) -> bool;
}

任何满足这些条件的类型都可以实现这个特征:

1
2
3
4
5
impl Signed for Number {
fn is_strictly_negative(self) -> bool {
self.value < 0
}
}

这样,Number 类型就拥有了 is_strictly_negative 方法。特征也可以包含默认实现:

1
2
3
4
5
6
7
trait Signed {
fn is_strictly_negative(self) -> bool {
self.value() < 0
}

fn value(&self) -> i32;
}

然后,类型只需要提供那些没有默认实现的方法:

1
2
3
4
5
impl Signed for Number {
fn value(&self) -> i32 {
self.value
}
}

Rust 的一个核心特征是 Drop,它允许你定义当值离开作用域时应该发生的事情:

1
2
3
4
5
impl Drop for Number {
fn drop(&mut self) {
println!("Dropping {}", self.value);
}
}

Number 实例离开作用域时,Rust 会自动调用 drop 方法。

枚举(Enums)允许你定义一个类型,该类型可以是多个不同变体中的一个。这对于值可以有多种但数量有限的类型特别有用:

1
2
3
4
5
6
7
enum WebEvent {
PageLoad,
PageUnload,
KeyPress(char),
Paste(String),
Click { x: i64, y: i64 },
}

与结构体一样,枚举的每个变体可以包含不同类型和数量的数据。你可以使用 match 表达式来操作枚举值:

1
2
3
4
5
6
7
8
9
fn inspect(event: WebEvent) {
match event {
WebEvent::PageLoad => println!("page loaded"),
WebEvent::PageUnload => println!("page unloaded"),
WebEvent::KeyPress(c) => println!("pressed '{}'", c),
WebEvent::Paste(s) => println!("pasted \"{}\"", s),
WebEvent::Click { x, y } => println!("clicked at x={}, y={}", x, y),
}
}

枚举也可以有方法:

1
2
3
4
5
6
7
8
9
10
11
impl WebEvent {
fn describe(&self) -> String {
match self {
WebEvent::PageLoad => String::from("page loaded"),
WebEvent::PageUnload => String::from("page unloaded"),
WebEvent::KeyPress(c) => format!("pressed '{}'", c),
WebEvent::Paste(s) => format!("pasted \"{}\"", s),
WebEvent::Click { x, y } => format!("clicked at x={}, y={}", x, y),
}
}
}

没有返回值的空函数:

1
2
3
fn greet() {
println!("Hi there!");
}

右箭头表示返回值类型:

1
2
3
fn foo() -> i32 {
1
}

模块管理

在Rust中,模块是用于组织代码、控制可见性以及支持代码重用的重要概念。Rust的模块系统是基于文件和目录组织的,这使得代码的组织变得清晰而灵活。下面是Rust模块管理的一些关键概念:

模块定义

模块通过mod关键字进行定义,可以在一个Rust文件中定义一个模块。例如:

1
2
3
4
// 在文件 mod_example.rs 中定义了一个模块
mod example {
// 模块的内容
}

模块路径

模块路径用于指定模块的位置。Rust使用::来表示模块路径。例如:mod_example::example

Rust的模块系统与文件系统有很强的映射关系。一个模块可以对应于一个文件,也可以对应于一个目录,包含多个文件。这使得项目的文件和目录结构能够与代码组织一致。

pub关键字

在Rust中,使用pub关键字来标识模块、结构体、枚举、函数等的可见性。只有被标记为pub的项才可以在其他模块中被访问。

1
2
3
4
5
6
7
8
// 在 example 模块中声明了一个公共的结构体
mod example {
pub struct MyStruct {
// 结构体的字段
}
}
// 在其他模块中使用 example 模块中的 MyStruct
use example::MyStruct;

mod.rs文件

文件本身是可以被作为模块引用的,这样可以更好地组织代码。
如果一个模块的内容比较复杂,可以在模块所在的目录中创建一个mod.rs文件,作为模块的“命名空间”,用于存放模块的具体实现。例如:

1
2
3
// 在 example 目录中创建 mod.rs 文件
// example/mod.rs
pub mod sub_module;

使用方式:

1
2
// 在其他模块中引用 example 模块
use example::sub_module;

这个目录用于存放模块的具体实现。这有助于清晰地分离模块的定义和实现。

cratesuper

crate关键字用于表示当前crate的根模块,而super关键字用于表示当前模块的父模块。

1
2
3
4
5
6
7
8
9
10
11
12
// 在 crate 根模块中
mod my_module {
// 在 my_module 中
mod sub_module {
// 在 sub_module 中,使用 super 表示 my_module
use super::super::my_function; // 调用父模块的函数
}
}

fn my_function() {
// 函数实现
}

模块的可见性规则

默认情况下,模块和其中的项对外部是不可见的。可以通过pub关键字调整可见性。Rust的模块系统强调了显式性,即除非明确指定为pub,否则默认情况下所有项都是私有的。

pub 有多种用法,包括:

  • pub:在默认情况下,Rust 中的项是私有的,只能在定义它们的模块中访问。使用 pub 关键字可以将项声明为公共的,使其在整个 crate 中都可见。
1
2
3
4
5
6
7
pub struct MyStruct {
pub field: i32,
}

pub fn my_function() {
// 函数实现
}

在这个例子中,MyStructmy_function 都被声明为公共的,可以在 crate 的任何地方访问。

  • pub(crate):限制了项的可见性仅在当前 crate 中。这使得项在 crate 外部是不可见的,但在同一个 crate 内的所有模块都可以访问。
1
2
3
4
5
6
7
pub(crate) struct InternalStruct {
// ...
}

pub(crate) fn internal_function() {
// 函数实现
}

在这个例子中,InternalStructinternal_function 只能在定义它们的 crate 中的任何模块中访问。

  • pub(super):限制了项的可见性仅在其父模块(即包含该项的模块)和其父模块的子模块中。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
mod parent_module {
pub(super) struct SuperStruct {
// ...
}

pub fn super_function() {
// 函数实现
}

mod child_module {
fn inner_function() {
// 在子模块中可以访问 SuperStruct 和 super_function
let my_struct = SuperStruct { /* ... */ };
super_function();
}
}
}

在这个例子中,SuperStructsuper_functionparent_module 中可见,但在 crate 中的其他模块不可见。

  • pub(self):限制了项的可见性仅在当前模块中。这使得项对于同一模块中的其他模块是不可见的。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
mod my_module {
pub(self) struct ModuleStruct {
// ...
}

pub(self) fn module_function() {
// 函数实现
}

mod submodule {
fn inner_function() {
// 在子模块中不能访问 ModuleStruct 和 module_function
// 这两个项对于同一模块中的其他模块是不可见的
}
}
}

在这个例子中,ModuleStructmodule_function 只能在 my_module 中的任何模块中访问。

这些可见性修饰符允许 Rust 程序员精确地控制项的可见性,从而确保代码结构的封装和安全性。

use指令

use 指令可用于将其他命名空间的名称 “引入作用域”:

1
2
3
use std::cmp::min;

let least = min(7, 1); // 1

也可以用紧凑的写法:

1
2
3
4
5
6
7
// 格子单独引入
use std::cmp::min;
use std::cmp::max;
// 从cmp分开
use std::cmp::{min, max};
// 从std分开也可以
use std::{cmp::min, cmp::max};

*可以通配引入:

1
use std::cmp::*;

Rust的模块系统是一个强大的组织和抽象工具,支持创建清晰、可维护、可重用的代码结构。了解和熟练使用模块系统有助于提高代码的可读性和可维护性。

错误类型和可选项

ResultOption 是 Rust 中用于错误处理和可选值的两个重要枚举类型。它们在处理不同类型的情况时非常有用。

Option 枚举类型

Option 类型用于表示一个值可能存在(Some)或不存在(None)。它通常用于返回一个可能为空的值。

1
2
3
4
enum Option<T> {
Some(T),
None,
}

示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
fn find_element(vec: &Vec<i32>, index: usize) -> Option<i32> {
if index < vec.len() {
Some(vec[index])
} else {
None
}
}

let numbers = vec![1, 2, 3];
match find_element(&numbers, 1) {
Some(value) => println!("Found: {}", value),
None => println!("Not found"),
}

Result 枚举类型

Result 类型用于表示一个操作可能成功(Ok)或失败(Err)。它通常用于返回一个可能会出错的操作结果。

1
2
3
4
enum Result<T, E> {
Ok(T),
Err(E),
}

示例:

1
2
3
4
5
6
7
8
9
10
11
12
fn divide(a: i32, b: i32) -> Result<i32, String> {
if b == 0 {
Err(String::from("Division by zero"))
} else {
Ok(a / b)
}
}

match divide(4, 2) {
Ok(result) => println!("Result: {}", result),
Err(e) => println!("Error: {}", e),
}

ResultOption 的关系

  • Option 用于表示一个值可能存在或不存在,而不涉及错误信息。
  • Result 用于表示一个操作可能成功或失败,并且可以携带错误信息。

在某些情况下,可以将 Option 转换为 Result,例如在需要提供错误信息时:

1
2
3
4
5
6
fn find_element(vec: &Vec<i32>, index: usize) -> Result<i32, String> {
match vec.get(index) {
Some(&value) => Ok(value),
None => Err(String::from("Index out of bounds")),
}
}

通过这种方式,可以更灵活地处理错误和可选值。

自定义错误类型

在 Rust 中,通常建议使用自定义的错误类型来更好地表达错误信息。可以通过枚举或结构体来定义自己的错误类型,并实现 std::fmt::Debugstd::fmt::Display trait 来提供可读的错误信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#[derive(Debug)]
enum MyError {
DivisionByZero,
CustomError(String),
}

impl std::fmt::Display for MyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MyError::DivisionByZero => write!(f, "Cannot divide by zero"),
MyError::CustomError(msg) => write!(f, "Custom error: {}", msg),
}
}
}

fn divide(a: f64, b: f64) -> Result<f64, MyError> {
if b == 0.0 {
Err(MyError::DivisionByZero)
} else {
Ok(a / b)
}
}

fn main() {
match divide(10.0, 0.0) {
Ok(result) => println!("Result: {}", result),
Err(error) => println!("Error: {}", error),
}
}

使用 ? 操作符

Rust 中的 ? 操作符可以用于快速地将 ResultOption 的值传递给包含错误处理的函数。它简化了错误传播的代码。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
fn operation1() -> Result<i32, &'static str> {
// ...
Ok(42)
}

fn operation2() -> Result<i32, &'static str> {
// ...
Ok(10)
}

fn main() -> Result<(), &'static str> {
operation1()?;
operation2()?;
Ok(())
}

错误处理 thiserror 和 anyhow

错误处理
anyhow提供了统一管理error的方式,任何error都可以存储在anyhow中。
thiserror提供了方便我们定义error的宏。

我们的Result类型都是anyhow::Result并且通过thiserror的宏来自定义错误。

单元测试

通过配置宏 #[cfg(test)],我们可以指定某个模块为测试模块,并且可以为模块内的函数配置 #[test] 以指定某个函数为测试实例。在 Rust 中,习惯性地会创建一个与模块同级的名为 tests 的模块,然后在该模块中编写测试函数。这些测试代码一般位于与源代码相同的文件中(也可以分成独立的文件编写)。下面的示例代码演示了如何简单验证两个数的相加。在本书中,测试代码按照类似的格式提供,旨在验证实现的正确性,并一定程度上提供对应函数或方法的使用示例。assert_eq! 是一个宏函数,用来断言相等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// 源代码模块
mod my_module {
pub fn add(a: i32, b: i32) -> i32 {
a + b
}
}

// 测试模块
#[cfg(test)]
mod tests {
use super::my_module;

// 测试函数
#[test]
fn test_add() {
assert_eq!(my_module::add(2, 3), 5);
}
}

本书的代码按章节模块化组织,每个章节都可以独立运行。使用 cargo test 命令,可以按章节前缀运行对应的单元测试(例如 cargo test ch1),也可以执行具体的测试函数(例如 cargo test ch1::skiplist::tests::it_works)。章节内容循序渐进,每章结束时会引入前面章节的模块。这种组织方式使读者能够逐步学习和理解每个模块的工作原理。每个章节相对独立,读者也可以跳跃式阅读。

读者可以根据需要修改每个独立模块,尝试理解其工作原理,或在自己实现过程中参考这些模块。这种结构旨在提供灵活性,使读者能够自由地使用和探索本书的代码。

所有权和引用

所有权是由编译器检查的,因此检查会非常严格。在Rust中,浅拷贝会移交所有权(move),而深拷贝(Copy)则会复制对象,从而避免所有权冲突。如果一个对象实现了Copy trait,也可以进行复制,不会与所有权产生冲突。在Rust中,只有copy和move两种操作。

引用不会获得所有权,因此也没有权利调用Drop。创建引用在Rust中被称为借用,因此借用和引用有时会混用。借用被视为一个动词,而引用被视为一个名词。如果希望在不产生Copy的情况下修改一个对象,可以使用可变引用。

Rust规定一个对象只能有一个可变引用,且不能同时存在其他的可变或不可变引用。

笔者推荐更详细的内容可以阅Rust Book
Rust nomicon,这两本书都是比较全面且标准的Rust教程。

在本书的数据库实现中,我们主要使用字节向量和字节切片,分别用 Vec<u8>&[u8] 表示,代表具有所有权的字节块和对字节块的借用。由于所有权的关系,如果我们只需要读取数据,会使用借用;如果需要保存写入的数据,则会使用具有所有权的对象 Vec<u8>&mut [u8] 是可修改的借用,实际上也具有所有权,当我们需要修改连续内存的一部分时,可以使用这种类型。

如果我们不需要所有权,as_refas_mut 可以为我们提供相应的引用。

Rust约定的迭代器类型如下,注意 IntoIter 有些不同,如果是一个切片的 IntoIter,返回的仍然是引用。

1
2
3
IntoIter - T
IterMut - &mut T
Iter - &T

as_derefas_deref_mut 可以帮助我们自动解多层引用。

避免盲目使用 .clone() 满足借用规则

借用检查器确保 Rust 用户在开发中不会产生不安全的代码。具体而言,它防止了两种情况:首先,只允许存在一个可变引用;其次,允许存在多个引用,但全部都是不可变引用。如果编写的代码不符合这些条件,当开发人员通过克隆变量来解决编译器错误时,就可能陷入这种反模式。

对于初学者而言,使用 .clone() 来解决借用检查器引起的混乱问题是很诱人的。然而,这样做会带来严重的后果。使用 .clone() 会导致数据的复制,两者之间的任何更改都不会同步,就像存在两个完全独立的变量一样。

有一些特殊情况,例如 Rc<T> 被设计成可以智能处理克隆。它在内部管理数据的精确一份拷贝,克隆它将只克隆引用。

还有 Arc<T>,它提供对在堆上分配的类型为 T 的值的共享所有权。在 Arc 上调用 .clone() 会产生一个新的 Arc 实例,它指向堆上与源 Arc 相同的分配,同时增加引用计数。

总的来说,克隆应该是经过深思熟虑的,要充分了解后果。如果使用克隆来消除借用检查器错误,这是可能正在使用这种反模式的一个很好的指示。

即使 .clone() 是一个糟糕模式的指示,有时写效率低下的代码也是可以接受的,比如:

  • 开发者仍然是个新手
  • 代码没有很大的速度或内存约束(比如黑客马拉松项目或原型)
  • 满足借用检查器真的很复杂,而你更愿意优化可读性而不是性能

如果怀疑存在不必要的克隆,应该充分了解《Rust Book》关于所有权的章节,然后评估是否需要这个克隆。

同时,务必在项目中始终运行 cargo clippy,这个 lint 工具会帮你检查一些不必要的 clone

函数借用参数的选择

函数参数会用到大量的借用,因为借用不会产生拷贝,但在使用借用的时候尽量使用直接的借用类型 &str&[T]&T 而不是 &String&Vec<T>&Box<T>。在作为参数的时候,后面的拥有所有权的类型(智能指针)可以自动转换成前面的类型,而反过来则不可以。例如,下面这个函数如果把参数改成 &String 是无法编译的。原因是直接的引用类型需要再分配一个对应的所有权类型才能和所有权类型的引用对齐,但是反过来进行一次解引用就可以获得直接引用。比如 a = Box<T>,其实相当于 &(*a),编译器自动进行了转化先解引用对应的直接类型然后直接引用。反过来的话 a = &T,那就得 &Box::new(*a),需要多创建这个 Box 对象,编译器就直接拒绝了。

1
2
3
4
5
6
7
8
9
fn demo(word: &str) {
}

fn main() {
let ferris = "Ferris";
let curious = "Curious".to_string();
demo(ferris);
demo(&curious);
}

不可变引用和 Rc

不可变引用的生命周期必须小于所有权的生命周期,但Rc不要求,只要最后一个引用离开生命周期则回收。

临时的可变性

可以将可修改对象重新赋值让可变性消失。这样做有一个好处就是如果你明确在之后不想修改该对象,而又人为(可能被合作者,或者两个月后的自己)错误地修改了,编译器就会帮你检查出来。我觉得更多的是在人的“阅读期”直观地明确代码的可变性。

1
2
3
4
5
6
let data = {
let mut data = get_vec();
data.sort();
data
};
// data 是可变的。
1
2
3
4
5
let mut data = get_vec();
data.sort();
let data = data;

// data 是不可变的。

协同性

协同性是Rust里面最难的部分了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
//fn two_cell_refs<'big: 'small, 'small>(
// // NOTE: these two lines changed
// big: Cell<&'big u32>,
// small: Cell<&'small u32>,
//) {
// assign(big, small);
//}

// 如果让mut reference扩大生命周期就会导致垂悬指针。
// Vec可以是因为Vec是有所有权的,所以不会出现垂悬指针。
fn two_refs<'big: 'small, 'small>(big: Vec<&'big u32>, small: Vec<&'small u32>) {
take_two(big, small);
}
fn take_two<T>(_val1: T, _val2: T) {}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn it_works() {
let skl = SkipList { head: None };
}
}

NonNull本质是一个*const T,从而使得NonNull可以与T协变,通过强制转换的方式让这个指针是可修改的。这是标准库中常用的一个对象,目的是让Vec这样的类型使用起来与T具有协变性。

1
2
3
pub struct NonNull<T> {
pointer: *const T,
}

具体的解释可以参考这里,目前笔者也没有完全理解这个概念。

Subtyping is the idea that one type can be used in place of another.

范型

Rust 中的泛型是一种强大的特性,它允许你编写适用于多种数据类型的代码,同时保持类型安全。通过泛型,可以编写更加灵活、抽象和可重用的代码,同时保持 Rust 的内存安全和零成本抽象。

以下是 Rust 中泛型的一些关键概念和用法:

泛型函数

在 Rust 中,你可以编写泛型函数,使其适用于多种类型。示例:

1
2
3
4
5
6
7
8
9
fn print_twice<T>(value: T) {
println!("{:?}", value);
println!("{:?}", value);
}

fn main() {
print_twice("Hello, Rust!");
print_twice(42);
}

在这个例子中,print_twice 是一个泛型函数,可以接受任意类型的参数,并执行相同的打印操作。

泛型结构体

可以为结构体定义泛型类型参数,以实现对不同类型的结构体的抽象。示例:

1
2
3
4
5
6
7
8
9
struct Point<T> {
x: T,
y: T,
}

fn main() {
let int_point = Point { x: 1, y: 2 };
let float_point = Point { x: 1.5, y: 2.5 };
}

在这个例子中,Point 结构体可以用于包含任何相同类型的坐标点。

泛型枚举

枚举也可以包含泛型类型参数,以增加其灵活性。示例:

1
2
3
4
5
6
7
8
9
enum Result<T, E> {
Ok(T),
Err(E),
}

fn main() {
let success: Result<i32, &str> = Result::Ok(42);
let failure: Result<i32, &str> = Result::Err("Error message");
}

在这个例子中,Result 枚举表示可能包含成功结果(Ok)或错误信息(Err),并分别包含了两个泛型类型参数。

泛型实现

可以对泛型类型实现 trait,以为多种类型提供相同的行为。示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
trait Printable {
fn print(&self);
}

impl<T: std::fmt::Debug> Printable for T {
fn print(&self) {
println!("{:?}", self);
}
}

fn main() {
"Hello, Rust!".print();
42.print();
}

在这个例子中,Printable trait 定义了一个 print 方法,然后对所有实现了 Debug trait 的类型实现了这个 trait。

Rust 的泛型提供了强大的抽象能力,帮助编写更加灵活和通用的代码。通过泛型,你能够在不失去类型安全的前提下,减少代码的冗余并提高代码的可维护性。

Read 和 Write

在本书中用的 Trait 比较多的是 Read、Write 和 Seek。

在 Rust 中,Read、Write 和 Seek 是三个与 I/O 操作密切相关的 trait,它们为实现输入和输出操作提供了通用的接口。这三个 trait 可以用于文件、网络套接字等不同的数据源和目标。

Read Trait
Read trait 定义了用于从数据源读取字节的方法。它主要包含一个方法:

1
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error>;

这个方法从实现 Read trait 的类型中读取字节,并将它们存储到提供的缓冲区 buf 中。方法返回一个 Result,其中 Ok(n) 表示成功读取了 n 个字节,Err 表示发生了错误。

Write Trait
Write trait 定义了用于将字节写入数据目标的方法。它主要包含一个方法:

1
fn write(&mut self, buf: &[u8]) -> Result<usize, Error>;

这个方法将提供的缓冲区 buf 中的字节写入到实现 Write trait 的类型中。方法同样返回一个 Result,其中 Ok(n) 表示成功写入了 n 个字节,Err 表示发生了错误。

Seek Trait
Seek trait 定义了用于在数据源中定位和移动读写指针的方法。它包含三个方法:

1
2
3
fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error>;
fn stream_len(&mut self) -> Result<u64, Error>;
fn stream_position(&mut self) -> Result<u64, Error>;
  • seek 方法通过给定的 SeekFrom 枚举类型,将读写指针移动到指定位置。
  • stream_len 方法返回数据源的总长度。
  • stream_position 方法返回当前读写指针的位置。

SeekFrom 枚举有以下几种可能的值:

  • SeekFrom::Start(n):将指针设置到数据源的起始位置加上 n
  • SeekFrom::End(n):将指针设置到数据源的末尾位置加上 n
  • SeekFrom::Current(n):将指针从当前位置移动 n 个字节。

这些 trait 为实现了文件、内存缓冲区等不同类型的数据源和目标提供了通用的接口,使得可以方便地使用相同的 I/O 操作代码处理各种类型的输入输出。在标准库中,例如 FileBufReader 都实现了这些 trait,使得对文件和缓冲区的读写变得简单和灵活。

本书中会大量用到这些 Trait,因为 Vec<u8>fs::File 都实现了这个接口。

Iterator

Iterator用于不可变迭代,IntoIterator用于获取所有权并进行迭代,MutIterator用于可变迭代。
在每个示例中,我们都使用了不同的方法进行迭代,并根据需要进行所有权的转移或可变引用的修改。
Iterator的实现,多种iterator的惯例,
Rust要求所有的集合数据类型都要有如下的迭代器,会返回上述方法。

1
2
3
fn iter(&'a self) -> Items<'a>;
fn into_iter(self) -> ItemsMove;
fn iter_mut(&'a mut self) -> ItemsMut<'a>;

Rust 的标准库为集合类型提供了一组通用的迭代方法,这些方法通常以 Iterator trait 的形式提供。这些方法通常分为三类,即标准的迭代方法 trio:

fn iter(&'a self) -> Items<'a>; 用来遍历&T。

这个方法返回一个不可变的迭代器,允许对集合中的元素进行只读的迭代。返回的 Items 类型是一个迭代器对象,其生命周期与集合本身相同,保证迭代器不会在集合被销毁前失效。

fn iter_mut(&'a mut self) -> ItemsMut<'a>; 用来遍历&mut T

这个方法返回一个可变的迭代器,允许对集合中的元素进行修改。返回的 ItemsMut 类型是一个可变迭代器对象,其生命周期与可变引用的生命周期相同,确保迭代器不会在可变引用结束后继续使用。

fn into_iter(self) -> ItemsMove; 用来遍历T,但如果T是一个引用类型其实和iter是类似的。

这个方法获取集合的所有权并返回一个拥有所有权的迭代器。这表示集合本身将不再可用,因为它的所有权已经转移到迭代器上。返回的 ItemsMove 类型是一个拥有所有权的迭代器对象。

为了提供更大的灵活性和符合 Rust 的所有权模型,这些方法还要求集合类型和对集合的(可变)引用都实现了 IntoIterator trait。这个 trait 提供了一个统一的方式,使得集合类型和引用都能够被用于 for 循环等需要迭代的上下文中。

这种设计使得 Rust 中的迭代更为一致和灵活,同时确保了在迭代过程中对所有权和可变性的严格控制。

两级迭代器(TwoLevelIterator)可以使用标准库的flat_map可以把两级迭代器展开成一个大的迭代器,适用于一些多层次迭代器的场景。

1
2
3
4
5
6
7
8
9
fn main() {
let data = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];

let flat_iter = data.iter().flat_map(|inner| inner.iter());

for &num in flat_iter {
println!("{}", num);
}
}

归并迭代器可以实现两个有序迭代器的归并排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
struct MergeSortIterator<L, R>
where
L: Iterator,
R: Iterator<Item = L::Item>,
{
left: L,
right: R,
left_value: Option<L::Item>,
right_value: Option<R::Item>,
}

impl<L, R> MergeSortIterator<L, R>
where
L: Iterator,
R: Iterator<Item = L::Item>,
{
fn new(left: L, right: R) -> Self {
let mut sorter = MergeSortIterator {
left,
right,
left_value: None,
right_value: None,
};

sorter.fetch_values();
sorter
}

fn fetch_values(&mut self) {
self.left_value = self.left.next();
self.right_value = self.right.next();
}
}

impl<L, R> Iterator for MergeSortIterator<L, R>
where
L: Iterator,
R: Iterator<Item = L::Item>,
L::Item: Ord,
{
type Item = L::Item;

fn next(&mut self) -> Option<Self::Item> {
match (self.left_value.take(), self.right_value.take()) {
(Some(left), Some(right)) => {
if left <= right {
self.left_value = self.left.next();
self.right_value = Some(right);
Some(left)
} else {
self.left_value = Some(left);
self.right_value = self.right.next();
Some(right)
}
}
(Some(left), None) => {
self.left_value = self.left.next();
Some(left)
}
(None, Some(right)) => {
self.right_value = self.right.next();
Some(right)
}
(None, None) => None,
}
}
}



fn main() {
let iter1 = vec![1, 3, 5, 7, 9].into_iter();
let iter2 = vec![10,11, 12, 13, 14].into_iter();

let merge_iter = MergeSortIterator::new(iter1, iter2);

for item in merge_iter {
println!("{}", item);
}
}

幽灵数据

如果SkipIterator使用&'a Link很奇怪,因为Link本身就是个Rc
在引用为0的时候回收,不需要生命周期的标记,但如果不用的话编译器会报错。
我们希望持有SkipNode的引用,生命周期应该和SkipNode一致,所以引入
幽灵数据,标记我们逻辑上关联的对象,因为结构体中没法引用这个类型。

1
2
3
4
5
6
7
8
9
10
11
12
struct SkipIterator<'a> {  
head: Link,
marker: PhantomData<&'a SkipNode>,
}

impl<'a> Iterator for SkipIterator<'a> {
type Item = &'a SkipNode;
fn next(&mut self) -> Option<Self::Item> {
None
}

}

性能调优

TODO

perf profile

tracing

第二章理解键值数据库

键值数据库的介绍

键值数据库(Key-Value Database)在NoSQL(非关系型数据库)范畴中占据重要地位,其采用简洁的键值的结构对数据对象进行存储和检索。
每个数据项由键(key)和关联的值(value)组成,类似于字典或哈希表的数据模型,其中键是唯一标识符,而值则是与之关联的数据。这种简单的键值对模型使得键值数据库在存储和检索简单数据时表现出色。

许多键值数据库支持分布式架构,能够在多个节点上存储数据,以提高性能和可靠性。亚马逊的DynamoDB是一个典型的例子,于1999年提出并成为高度可扩展的键值数据库系统,满足了亚马逊的分布式存储需求,其思想和设计对后来的键值数据库系统产生了深远的影响。

Redis是另一备受欢迎的键值数据库,由Salvatore Sanfilippo于2009年创建。它是一种开源的内存中数据结构存储系统,支持多种数据结构,包括字符串、哈希表和列表,因而成为广泛使用的键值数据库。

LevelDB是由Google开发的高性能键值数据库,采用LSM树(Log-Structured Merge Tree)的结构。在2012年,RocksDB发布,进一步优化了性能和存储效率,受到了广泛的应用。

键值数据库以其快速的读写性能而著称,尤其在需要快速检索特定键的情境下表现出众。它们通常具备横向可扩展性,能够通过添加更多节点来处理更大的负载。其对值的数据结构没有强制规定,因此可以灵活存储各种类型的数据。本书聚焦于单机键值数据库,不会涉及一些集群数据库相关的分布式能力和扩展能力。

LSM

本书将会实现一个类似LevelDB的键值数据库,其核心数据结构是LSM(Log-Structured Merge Tree)。

LSM最早来源于1996年的一篇论文[^1],而被广为人知的契机则是Google的BigTable[^2]论文,其中的文件格式基于LSM。Google开源了类似键值数据库的单机版本,即LevelDB[^3]。传统数据库通常使用B-Tree类的数据结构,它具有许多优点。随着LevelDB的诞生,基于LSM-Tree的数据结构也逐渐进入人们的视野。

LSM是一种用于存储和管理大规模键值对数据的数据结构,它在特定应用场景中非常有效,具有以下优势:

  • 高写入吞吐量:LSM树通过将写入操作追加到顺序写的文件中,并使用内存和磁盘两级存储结构,实现了高效的写入吞吐量。写入操作可以在内存中迅速完成,然后异步地合并到磁盘上的存储文件中。

  • 压缩和合并:LSM树通过定期合并和压缩磁盘上的存储文件,提高了读取性能。这些合并操作使得数据在磁盘上以更为紧凑的形式存储,减少了读取时需要扫描的数据量。

  • 高吞吐读取:LSM树的结构使得范围查询更加高效,因为数据在磁盘上以顺序方式存储。这对于分析型工作负载非常有利。

  • 容错性:由于LSM树的写入操作是追加到预写日志文件中的,这提供了一种容错机制。即使在写入过程中出现故障,系统也可以通过重新应用日志来恢复。

  • 可扩展性:LSM树适用于大规模的分布式存储系统,支持数据的水平扩展。各个节点可以独立地执行写入和合并操作,从而提高了系统的可扩展性。

  • 减少随机I/O:LSM树的追加写入方式减少了磁盘上的随机I/O,有助于提高写入性能。这对于使用磁盘作为主要存储介质的系统尤为重要。

[^1]: 《The Log-Structured Merge-Tree (LSM-Tree)》:这是LSM的原始论文。
[^2]: 《Bigtable: A Distributed Storage System for Structured Data》(作者:Fay Chang等):这是Google的Bigtable论文,该文档介绍了Bigtable如何使用LSM树来管理大规模分布式数据存储。
[^3]: 《LevelDB: A Fast Persistent Key-Value Store》(作者:Jeff Dean, Sanjay Ghemawat):这是Google开发的LevelDB的论文,该数据库使用了LSM树结构。论文提供了对LSM树及其在LevelDB中的应用的深入了解。

HDD和SSD

LevelDB是为了针对HDD的追加写的特性而设计的,有很多优化是基于SSD的。
有一本书关于SSD的《深入浅出SSD》详细阐述了SSD的特性。

第三章构建数据库引擎

基本架构

整个数据库的基本架构下图,一般来说,一个实现LSM键值存储接口所使用的对象是任意字节流。
作为搜索结构,所有数据会有序排列在存储中,常用的操作有插入、更新、获取、遍历和删除等。
为了利用追加写的特性,其中的删除一般是通过插入“墓碑”来代替而不会真正的删除,
而更新则是追加一个键的新版本,所以整个数据库只用到了追加写不使用随机写,充分利用
机械硬盘的追加写性能远远高于随机写的特性。机械硬盘在写之前需要进行磁片上的寻址操作,
这导致随机写相较于顺序写多了很多寻址操作,其之间的性能大致差了100倍。

用到这追加写的文件就是预写日志。可靠的单机数据库需要确保用户调用写入接口返回成功后,即使进程重启(甚至因机器宕机而中断)也不会导致数据丢失。
采用预写式日志是数据库中常见的一种手段,数据会按照先写内存再到日志的顺序进行更新。
由于数据已经持久保存在磁盘上,即使发生异常,内存数据丢失,也能够通过重放预写日志确保数据的完整性。
当前的预写式日志文件会在内存的形式一般叫MemTable。
很多数据库中writer_buffer相关的配置指的就是这个MemTable的大小,因为某种程度上它就是预写日志的内存缓存。

当MemTable达到容量上限(大多数数据库的默认设置为4KB),内存表的内容会被保存在持久化的文件存储中,接着日志文件可以安全删除。
这个表在文件系统上的形式是一个不可更改的搜索结构,一般会用SST(Static Sorted Table),顾名思义就是不可更改的、有序的文件结构。
该文件的不可更改的特性很像Rust变量默认的immutable。内存数据结构需要满足高效的查找和插入,其底层的数据结构一般用SkipList来实现。

数据库会对SSTable进行分层合并,由上层(或上上层)的SSTable合并成新的SSTable,并写入到下一层。
这个过程被称为major compact。因此,层数越小,数据越新,层数越大,数据越久远。

为了限制内存大小,当MemTable达到一定大小后,会转换为不可变内存表。
会作为整个数据库的第0层的SStable,是比较特殊的一层,这个合并过程称为minor compact。LevelDB的由来正是这种分层合并的结构。

当有新的文件产生时需要一个清单文件对这些文件进行记录,一般会使用一个叫MANIFEST的文件保存,用于记录各阶段文件集合信息。
为了更快速的查找,可能还会记录一些附加信息,例如文件大小、最大最小key等。这个文件相当于保存了所有在持久化存储上的SStable的元信息。

对于读操作,需要从内存表、不可变内存表、level-0 SSTable里查找,然后再从level 1中的文件开始查找。

MemTable

MemTable是一个内存中的数据结构,在数据被刷新到SST文件之前保存它们。它相当于一个读写的缓存——新的写总是将数据插入到MemTable中,而读必须在从SST文件读取之前查询MemTable,因为MemTable中的数据是较新的。一旦内存表被填满,它就变成不可变的,并被一个新的内存表所取代。一个后台线程将MemTable的内容刷新到一个SST文件中,之后MemTable就可以被销毁了。MemTable的大小一般是64MB。

SkipList

SkipList(跳表)是一种常用于实现有序存储的数据结构,通常用于构建内存表(MemTable)等应用场景。相比于平衡树等结构,SkipList 不需要复杂的旋转调整来保持平衡,其实现较为简单且易于理解。

SkipList 由一系列节点组成,每个节点包含键值对以及多个层级的指针。节点的高度由一个概率随机决定,这使得 SkipList 在期望上具有 O(log n) 的搜索复杂度。节点结构如下:

1
2
3
4
5
6
struct SkipNode {
key: Vec<u8>,
value: Vec<u8>,
h: usize,
next: Vec<Link>,
}

SkipList的搜索操作是SkipList中的基本操作之一。从头节点开始,逐层向下搜索,如果找到等于目标键的节点,则返回对应值;如果找到大于目标键的节点,就下移一层继续搜索;如果找到小于目标键的节点,就向右移动。这样,通过多层级的指针,可以有效地减少搜索路径。
但是skip list的常数项相对小很多。skip list在空间上也比较节省。一个节点平均只需要1.333个指针(甚至更少),并且不需要存储保持平衡的变量。

图示
对于层级链表,每增加一层链表,节点的搜索路径就会减半,提高搜索效率。以下是一个示例,展示了如何通过层级链表来降低搜索时间复杂度。
对于一个链表来说,搜索的时间复杂度是O(n)的,搜索5需要5次(1,2,3,4,5)。

1
1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8

这个搜索是线性的,但是如果再加一个链表,每隔一个结点取一次,可以节省一半的时间,从最高level的链表开始到小于后置节点时向下一个level搜索,这时5要搜索4次(1,3,5),找到4需要3次(1,3,4)。

1
2
1 - - - - 3 - - - - 5 - - - - 7
1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8

以此类推,再增加一个链表的话,此时5只要搜索2次(1,5),4则没变还是3次(1,3,4)。

1
2
3
1 - - - - - - - - - 5 - - - - - 
1 - - - - 3 - - - - 5 - - - - 7
1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8

这里每个节点的元素只需要存一次,但是每个节点的高度之内都要保存对应指针。

SkipList相较于一些有序数据结构比如平衡树来说不需要做一些旋转的调整来保持树的平衡,节点的高度是基于概率的,如果设置概率为1/2的话可以通俗的理解为“丢硬币,每次正面则这个节点的高度提高一层”,从期望上来说是可以被证明为log(n)的。因为SkipList和链表很接近,相较于平衡树来说手写更容易实现。但不巧的是链表在Rust里面是地狱难度的实现,本章的篇幅因此会比较长。

首先我们通过rand::Rng生成一个u32的随机数,用最低位的连续1的数量作为高度,每次最低位为1时,高度加1,并且随机数右移,模仿丢硬币的过程。

1
2
3
4
5
6
7
8
9
10
11
12
use rand::Rng;
// 1/2的概率生成节点的高度
fn random_height(total: usize) -> usize {
let mut h = 0;
// 生成u32随机数
let mut r = rand::random::<u32>();
while r & 1 == 1 && h < total {
h += 1;
r >>= 1;
}
h
}

首先我们定几个数据结构,Rc是引用计数,节点会被多个前置节点指向,所以我们需要使用Rc包裹我们的节点。
RefCell是一个智能指针,可以将借用的规则推迟到运行时检查,通过borrow_mut方式及时对象没有使用mut声明
也可以获得可修改引用,RefCell会在内部记录可修改引用的唯一。Option用来表示空指针。所以组合起来我们构造了一个Link类型。

1
2
3
use std::cell::RefCell;
use std::rc::Rc;
type Link = Option<Rc<RefCell<SkipNode>>>;

我们围绕的对象都是字节串,所以键值都用Vec来存储,为了突出代码的说明性质不引入过多的复杂性,这里没有使用范型。

1
2
3
4
5
6
const MAX_HEIGHT: usize = 32;

struct SkipList {
head: Link, // head 是一个辅助节点,不存储数据
}

下面两个函数都被包裹在impl SkipList {}中,为了阅读方便进行了省略,创建SkipList的时候会创建一个非None的key和value都为vec![]的占位节点。
从上向下寻找,遇到比自己大的值就下移一层继续寻找,如果遇到比自己小的值就向后移动。
实现链表有两个比较值得一看的文章:

搜索操作是 SkipList 中的基本操作之一。从头节点开始,逐层向下搜索,如果找到等于目标键的节点,则返回对应值;如果找到大于目标键的节点,就下移一层继续搜索;如果找到小于目标键的节点,就向右移动。这样,通过多层级的指针,可以有效地减少搜索路径。
我们的接口参照标准库中collections的形式,insert需要获取key value的所有权,get则返回引用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
fn get(&mut self, key: &[u8]) -> Option<Vec<u8>> {
let mut cur = self.head.clone();
for l in (0..MAX_HEIGHT).rev() {
while let Some(next) = cur.clone().unwrap().borrow().next[l].clone() {
if &next.borrow().key[..] > key {
break;
} else if &next.borrow().key[..] < key {
cur = Some(next);
} else {
return Some(next.borrow().value.clone());
}
}
}
None
}

如果是ge的实现,则在level为0时选择大于的值返回。

1
2
3
4
5
6
7
std::cmp::Ordering::Greater => {
// no equal key, just return the first greater key.
if l == 0 {
return Some((next.borrow().key.clone(), next.borrow().value.clone()));
}
break;
}

插入和搜索类似,当找到下一项大于自己或者为None时将自己链接到当前节点之后,如果下一项小于自己则向后移动。插入操作也是 SkipList 中的关键操作。从头节点开始,逐层向下搜索,找到合适的位置插入新节点。为了保持 SkipList 的有序性,需要在每一层级中正确地插入节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
fn insert(&mut self, key: &[u8], value: &[u8]) {
let h = random_height(MAX_HEIGHT);
let new_node = Rc::new(RefCell::new(SkipNode {
key: key.to_vec(),
value: value.to_vec(),
h,
next: vec![None; MAX_HEIGHT],
}));
let mut cur: Link = self.head.clone();
for l in (0..h + 1).rev() {
while let Some(cur_node) = cur.clone() {
let mut cur_node = cur_node.borrow_mut();
match cur_node.next[l].clone() {
Some(next) => {
if &next.borrow().key[..] < key {
// move
cur = Some(next);
} else if &next.borrow().key[..] > key {
cur_node.next[l] = Some(new_node.clone());
new_node.borrow_mut().next[l] = Some(next.clone());
break;
} else {
next.borrow_mut().value = value.to_vec();
break;
}
}
None => {
cur_node.next[l] = Some(new_node.clone());
}
}
}
}
}

链表这类数据结构是Rust中地狱级难度,如果要实现一个通用的链表类的数据结构涉及到协同性(Covariance),标准库中用到了很多unsafe的实现,通过指针的方式简化了代码,例如我们要实现双向链表的话Rc就不适用了,因为指向时成环的。

我们在实现链表的时候是无法依赖编译器帮我们自动释放空间的,因为编译器默认的行为会是一个递归调用Drop。
Box调用Drop以后Node占用的内存被释放,Node中的next的Box就无法被调用Drop了,这在别的语言里面很好实现,只要暂存next就好了。
但是,Box是一个有所有权的指针,显然Rust的所有权系统不允许让自己在“死亡”以后还被别人获取了所有权。

1
2
3
4
5
6
impl Drop for Box<Node> {
fn drop(&mut self) {
self.ptr.drop(); // 先释放指向的对象再释放自己
deallocate(self.ptr);
}
}

所以链表我们要手动实现Drop函数,这里面把box替换出来变成Empty,这样在“死亡”时就可以避免调用next了(毕竟next已经为空了不会有递归调用)。

我们不可以move一个可变引用即使move给了自己,理论上这不算错误,所以需要mem::replace,如果想要实现move就得通过replace的方法。

1
2
3
4
5
6
7
8
9
10
struct Buffer<T> { buf: Vec<T> }

impl<T> Buffer<T> {
fn replace_index(&mut self, i: usize, v: T) -> T {
// error: cannot move out of dereference of `&mut`-pointer
let t = self.buf[i];
self.buf[i] = v;
t
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
impl Drop for List {
fn drop(&mut self) {
let mut cur_link = mem::replace(&mut self.head, Link::Empty);
// `while let` == "do this thing until this pattern doesn't match"
while let Link::More(mut boxed_node) = cur_link {
cur_link = mem::replace(&mut boxed_node.next, Link::Empty);
// boxed_node goes out of scope and gets dropped here;
// but its Node's `next` field has been set to Link::Empty
// so no unbounded recursion occurs.
}
}
}

[并发写入] todo!()

[插入提示] todo!()

预写日志

下面我们展开讲解预习日志的相关知识和实现方式,日志文件切分成了大小为32KB的连续block块,block由连续的log record组成。
预写日志每次追加写都以一个record为单位,record的负载不作具体的规定可以是任意内容。

Record

Record的格式如下,首先是一个4个字节的CRC校验值紧跟着一个两个字节的u16表示长度,然后是一个字节表示类型,最后是实际的负载。

1
2
3
+---------+-----------+-----------+--- ... ---+
|CRC (4B) | Size (2B) | Type (1B) | Payload |
+---------+-----------+-----------+--- ... ---+

如果我们使用u64保存长度是可以存储很大的数据的,RocksDB的一篇数据库键值规模的分析[^4]中提到一般的键是几十个字节,
值是十几KB的量级,u64显得太长了,当大量小对象存在的时候会比较多余,所以u16是比较合适的,但我们希望这个record又是可以扩展保存更长的
数据的,所以如果负载超过限制的32KB的话就会分成多个record存储。
后面会提到批量原子写的相关内容,如果多个key需要同时写入的话这个payload就会比较大了。

[^4]: 《Characterizing, Modeling, and Benchmarking RocksDB Key-Value Workloads at Facebook》

负载长度的扩展

我们按照LevelDB的格式使用Type表示记录的连续性,Log Type有4种:FULL = 1、FIRST = 2、MIDDLE = 3、LAST = 4。FULL类型表明该log record包含了完整的用户数据,用户数据可能比较大,超过了当前block的剩余大小,就需要分成几条log record,第一条类型为FIRST,中间的为MIDDLE,最后一条为LAST。也就是:

  • FULL,说明该log record包含一个完整的用户数据;
  • FIRST,说明是log record的第一条用户数据;
  • MIDDLE,说明是log record中间的用户数据;
  • LAST,说明是log record最后的一条用户数据。
    参考LevelDB文档中的例子,考虑到如下序列的用户数据:
  1. A: length 1000
  2. B: length 97270
  3. C: length 8000

A作为FULL类型的record存储在第一个block中; B将被拆分成3条log record,分别存储在第1、2、3个block中,这时block3还剩6byte,将被填充为0; C将作为FULL类型的record存储在block 4中。

由于一条log record长度最短为7(4个字节的CRC加2个字节的Size加1一个字节的Type),如果一个block的剩余空间小于等于6个字节,那么将被填充为空字符串,长度为7的log record是不包括任何用户数据的空记录。

LevelDB将WAL文件按块划分还有一个好处是能够按块进行切分。对于一些类似MapReduce的处理程序来说比较友好,
按照block读取record直到碰到FULL或者FRIST就可以作为一个分片的边界了。

大小端

大小端(Endian)是计算机体系结构中的一个重要概念,用于描述多字节数据在内存中的存储方式。
它分为两种类型:大端序(Big-Endian)和小端序(Little-Endian),它们的区别在于多字节数据的字节排列顺序。

大端序(Big-Endian):在大端序中,最高有效字节(Most Significant Byte,MSB)位于最低内存地址,而最低有效字节(Least Significant Byte,LSB)位于最高内存地址。这意味着数据的各个部分从高位到低位依次存储。举例来说,十进制数值513在大端序下会以两个字节0x02(512)和0x01(1)的顺序存储。

小端序(Little-Endian):相反,小端序中,最低有效字节(LSB)位于最低内存地址,而最高有效字节(MSB)位于最高内存地址。这意味着数据的各个部分从低位到高位依次存储。以相同的例子,十进制数值513在小端序下会以两个字节0x01和0x02的顺序存储。

1
2
3
4
5
6
7
8
// Convert to little-endian bytes
let le_bytes = 513u16.to_le_bytes();
println!("Little-endian bytes: {:?}", le_bytes);
// Output: Little-endian bytes: [1, 2]
// Convert to big-endian bytes
let be_bytes = 513u16.to_be_bytes();
println!("Big-endian bytes: {:?}", be_bytes);
// Output: Big-endian bytes: [2, 1]

这两种字节序的差异可能在跨平台数据交换或者数据解释方面引发问题。例如,当你从一个大端序计算机向一个小端序计算机传输数据时,需要进行字节序转换,以确保数据正确解释。这样的差异在网络通信、文件存储、数据传输等领域都有广泛的应用。

计算机体系结构、操作系统和编程语言通常会规定默认的字节序,但程序员和开发人员需要了解大小端的概念,以确保数据在不同系统之间正确传递和解释。在考虑兼容性的情况下,我们可以选择固定一种大小端的方式。在本书中,我们选择小端编码,这也是X86的CPU的模式。

CRC

CRC(Cyclic Redundancy Check,循环冗余校验)是一种高效的检错码,用于数据的校验。它基于模二除法进行计算,通过计算结果的余数来生成校验值。模二除法相当于是一种不借位的减法,因为0减1在模二运算后仍为1。在二进制领域,模二运算等效于异或操作,因此CRC可以通过位移和异或的方式进行快速计算。

CRC32是一种广泛使用的CRC类型,其余数为32位,对应的除数是33位的多项式。该多项式表示为:X^{32}+X^{26}+X^{23}+X^{22}+X^{16}+X^{12}+X^{11}+X^{10}+X^8+X^7+X^5+X^4+X^2+X^1+X^0 ,用二进制表示为0x4c11db7。在具体实现中,这个多项式的概念并不直接涉及,而是专注于基于二进制的计算过程。CRC32常用于键值数据库,以4字节大小保存校验和。

CRC的数据开头添加0不影响结果,所以会预设置一个余数初始值0XFFFFFFFF,和结果异或值0XFFFFFFFF,这个值在其他CRC类型中可能是别的值。
余数初始值就是余数的一个初始值,结果异或值就是在计算完余数之后要再用结果异或值进行一次异或运算。

CRC一般的实现有一个反向bit顺序,这个和一些硬件的传输顺序有关系,有些硬件设备是把最低位的bit先进行传输,同时在软件层面代码的实现上会更简单一点,相当于我们把最低位当成了最高位,在一些教科书中可能不会明确这个具体实现上的差异。逆向计算的时候,0x4c11db7的逆向表示是0xedb88320。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
pub fn crc32(data: &[u8]) -> u32 {
// 余数预设值
let mut crc: u32 = 0xffffffff;
for byte in data {
crc ^= u32::from(*byte);
for _ in 0..8 {
// 如果高位是1则进行模二除法(异或)
if crc & 1 == 1 {
crc = (crc >> 1) ^ 0xedb88320;
} else {
//如果高位是0则进行
crc >>= 1;
}
}
}
// 余数异或值
crc ^ 0xffffffff
}

当然这个实现是最简单的版本,每次只进行1bit的位移。

CRC32有一个查表法的快速计算方法。CRC32的查表法是一种用于加速计算的优化方式,通过预计算并存储查表,可以显著提高CRC的计算速度。这种优化的核心思想是将CRC的计算过程中的每个字节或更小颗粒度的数据提前计算并存储在一个查表中,以便在实际计算中直接查表获取结果,而不需要逐位进行运算。
如果是一个字节则要保存一张256大小的除法表。

在实际应用中,可以按照字节、4位或16位的颗粒度进行预计算并存成表。这样,每次计算CRC时,只需按颗粒度查表,从而大幅减少了计算的复杂度,提高了效率。

除了查表法,CRC32的另一种优化方式是通过指令级优化,主要依赖于SSE(Streaming SIMD Extensions)和PCLMULQDQ指令集。

SSE指令集:SSE是Intel引入的一组指令,用于执行单指令多数据(SIMD)操作。通过使用SSE指令,可以同时对多个数据进行相同的操作,从而提高并行计算能力,加速CRC32的计算过程。

PCLMULQDQ指令集:PCLMULQDQ指令集是Intel和AMD的x86-64架构中的一组指令,用于执行多项式乘法。CRC32的计算可以看作是多项式乘法,因此使用PCLMULQDQ指令集可以更高效地执行这一计算过程。

通过结合SSE和PCLMULQDQ指令集,可以在硬件层面上进一步提升CRC32的计算性能,使其更适用于高性能计算和大规模数据处理。一些快速计算crc的库基本是运用了

VARINT

Varint使用7个比特的组合来表示整数的值,其中每个字节的最高位用于指示是否有更多的字节。如果最高位为1,表示后面还有一个字节;如果最高位为0,表示这是最后一个字节。这种设计使得解码过程相对简单,而且可以高效地处理不同大小的整数。
下面是一个转化varint为u64的例子,这里定义了一个Varint的trait,使用这个trait是为了展示Rust的Trait的一个特点:可以
给外部对象定义方法。这样通过use Varint,u64就会拥有to_varint_bytes的方法。
补充说明:Rust不允许给外部对象定义外部的Trait的实现,如果可以的话相当于可以篡改外部模块的实现了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
pub trait Varint {
// Define your trait methods here
fn to_varint_bytes(self) -> Vec<u8>;
fn from_varint_bytes(data: &[u8]) -> (Self, usize)
where
Self: Sized;
}

impl Varint for u64 {
fn to_varint_bytes(self) -> Vec<u8> {
let mut value = self;
let mut bytes = Vec::new();

loop {
// 0b是表示binary的表示方法
let mut byte = (value & 0b0111_1111) as u8;
value >>= 7;

if value != 0 {
byte |= 0b1000_0000;
}

bytes.push(byte);

if value == 0 {
break;
}
}

bytes
}

fn from_varint_bytes(data: &[u8]) -> (Self, usize) {
let mut value = 0;
let mut shift = 0;
let mut bytes_read = 0;

for &byte in data {
value |= ((byte & 0b0111_1111) as u64) << shift;
shift += 7;
bytes_read += 1;

if byte & 0b1000_0000 == 0 {
break;
}
}

(value, bytes_read)
}
}

内部Key

存储于数据库中的Key携带了额外信息:Sequence和Meta。
Sequence是一个单调递增的序列值,meta用于保存key的类型比如是一个用于标记删除的key。
其结构如下:

| User key (string) | sequence number (7 bytes) | meta (1 byte) |

Rust定义如下:

1
2
3
4
5
6
pub struct InternalKey {
pub user_key: Vec<u8>,
trailer: u64,
}
const INTERNAL_KEY_SEQ_NUM_MAX: u64 = (1 << 56) - 1;

编解码的方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

fn make_trailer(seq: u64, kind: InternalKeyKind) -> u64 {
(seq << 8) | (kind as u64)
}

pub fn decode(mut encoded_key: Vec<u8>) -> Option<InternalKey> {
if encoded_key.len() >= 8 {
let trailer =
u64::from_le_bytes(encoded_key[encoded_key.len() - 8..].try_into().unwrap());
encoded_key.resize(encoded_key.len() - 8, 0);
return Some(InternalKey {
user_key: encoded_key,
trailer,
});
}
None
}
pub fn encode(self) -> Vec<u8> {
let mut encoded_key = self.user_key;
encoded_key.extend_from_slice(&self.trailer.to_le_bytes());
encoded_key
}

trailer是一个u64,高位56bit保存sequence number,最后一个字节保存meta,目前只保存了key的类型标志。
Key的类型一般有两种一种是插入一种是删除,对于一个key的删除就是插入一个带类型为deletion的key的internal key。
其他的类型也可以扩展这最后的字节里面。

在Rust中如果一个对象实现了Deref,编译器可以自动帮助该对象进行引用的类型转换,这样就是为什么函数参数一般用
&str而不是用&String,因为&String可以被编译器自动转换成&str。为了让编译器让&InternalKey可以自动转换成
&[u8],我们实现如下的DerefTrait。标准库中的PathPathBufStringstr[u8]的引用都是
实现了彼此的Deref所以看到这些类型和函数的入参不一致的时候可以看看是不是有这个规则被编译器自动解引用了。

1
2
3
4
5
6
7
8
// Implementing Deref for InternalKey
impl Deref for InternalKey {
type Target = [u8];

fn deref(&self) -> &Self::Target {
&self.content
}
}

在如下的代码中会方便很多,尽管参数是&[u8]但还是依旧可以用&InternalKey作为参数。

1
2
3
4
5
6
7
8
fn takes_slice(slice: &[u8]) {
println!("{:?}", slice);
}

fn main() {
let key = InternalKey::new(b"123");
takes_slice(&key); // Now this will implicitly call deref()
}

但有些看起来可以自动deref的其实不行例如:&u64&usize,因为usize不一定是64位的所以标准库没有实现AsDeref。

Sequence Number

我们的数据库目前没有完整的事务,但是提供了一定的读的一致性视图和批量写入的原子性,
Sequence Number作为Key的一部分保存了对应的版本信息。
Sequence Number是为了实现Snapshot和原子的批量而依赖的,类似多版本控制的版本,是一个单增的序号。
对于一个读,特别是遍历的时候,会不读取比当前snapshot的sequence大的key,从而保证读的试图的一致性。
在批量写的情况中也是类似的。例如下面的例子来自于LevelDB的文档:将key1的值移动到key2,如果put key2成功之后在del key1之前失败了,
那么就会存在两个key存储了同一个值。

1
2
3
value = get key1
put key2 = value
del key1

通过原子的批量写可以避免这个问题

1
2
3
4
5
value = get key1
wb = new write batch
wb.del key1
wb.put key2 = value
wb.write

只有全部写入以后seq才能增加,不然在数据库重新启动以后发现了大于commited的seq就会放弃重放这些数据。

原子批量写也可以得益于批量写的能力增加写入的带宽。

WriteBatch

WriteBatch将多个写合并成一个写来实现原子性,格式如下。

开头是一个被所有entry共享的sequence,之后跟着一个u32的计数器。

1
2
3
+-------------+------------+--- ... ---+
| SeqNum (8B) | Count (4B) | Entries |
+-------------+------------+--- ... ---+

每个entry的开头是一个byte的Kind(上文提到的Put,Del等类型标记),然后是一个或者两个varbytes,就是常见的varuint32的长度和对应N个字节的数据,取决于类型,例如删除就只有一个。

1
2
3
+-----------+-----------------+-------------------+
| Kind (1B) | Key (varstring) | Value (varstring) |
+-----------+-----------------+-------------------+

Rust的实现如下:

1
2
3
pub struct Batch {
entries: Vec<u8>,
}

设置count和sequence号,

1
2
3
4
5
6
7
fn set_count(&mut self, count: u32) {
self.entries[8..12].copy_from_slice(&count.to_le_bytes());
}

pub fn set_seqnum(&mut self, seqnum: u64) {
self.entries[..8].copy_from_slice(&seqnum.to_le_bytes());
}

put操作,del操作也是类似的只是没有value,并且meta是类型不是Value而是Deletion。

1
2
3
4
5
6
7
8
9
10
pub fn put(&mut self, key: &[u8], value: &[u8]) {
self.entries.push(InternalKeyKind::Value as u8);
self.entries
.extend_from_slice(&(key.len() as u32).to_varint_bytes());
self.entries.extend_from_slice(key);
self.entries
.extend_from_slice(&(value.len() as u32).to_varint_bytes());
self.entries.extend_from_slice(value);
self.set_count(self.count() + 1);
}

在WAL中保存WriteBatch

Snapshot

Snapshot提供一致性的读视图,本质是一个sequence生成器和管理器,任何大于snapshot的key都不会被读取,从而保证读的一致性(特别是在遍历的时候)。在获取的时候会使用sequence来搜索,小于该sequence的key不会被搜索到。
如果是墓碑值也会过滤掉。

1
2
3
4
5
6
7
8
9
10
11
12
13
fn get_internal(&self, key: &[u8], seq: u64) -> Option<Vec<u8>> {
let lookup = InternalKey::new(key, seq, InternalKeyKind::Value);
if let Some((internal_key, value)) = self.mt.get_ge(&lookup.encode()) {
let internal_key = InternalKey::decode(internal_key).unwrap();
if internal_key.user_key == key
// not deleted
&& internal_key.trailer & 0xff != InternalKeyKind::Deletion as u64
{
return Some(value);
}
}
None
}

内部比较器

InternalKey可以重载比较符,在一些场景下比较方便。
但是一些场景我们需要对&[u8]进行比较,这个时候重载比较符是不被允许的。
这样我们定义一个额外的用于比较相关的数据结构。

可以看到上面我们使用了get_ge代表搜索大于等于这个internal key的key。
为了实现大的sequence不被搜索到我们希望seq是降序排序的。
由于user key按升序排列,但希望seq能够按“降序”排列,我们需要定义一个内部比较器,
user key 升序但是当user key 相同时seq大的更“小”。
这样我们的搜索结构就能保持不变,只要寻找的时候找到第一个“大等于” user key,seq 的internal key就可以。
在MemTable中的SkipList将Greater和Equal合并。
我们要给memtable加上get_ge的方法,寻找第一个大于等于key的函数,SkipList的搜索过程中
要多加一个条件,当level为0时,还没有找到相等的值,就返回当前的较大值。

1
2
3
4
5
6
7
std::cmp::Ordering::Greater => {
// no equal key, just return the first greater key.
if l == 0 {
return Some((next.borrow().key.clone(), next.borrow().value.clone()));
}
break;
}

自定义比较函数,使用静态结构体保存函数集合。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
use std::{
cmp::{self, min, Ordering},
mem,
};

#[derive(Clone, Copy)]
pub struct Comparator {
pub cmp: fn(&[u8], &[u8]) -> Ordering,
}

pub const INTERNAL_COMPARATOR: Comparator = Comparator { cmp: internal_cmp };

// user key + seno + meta
pub(super) fn internal_cmp(a: &[u8], b: &[u8]) -> Ordering {
assert!(a.len() >= 8);
assert!(b.len() >= 8);
match &a[0..a.len() - 8].cmp(&b[0..b.len() - 8]) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => {
let a_trailer = u64::from_le_bytes(a[a.len() - 8..].try_into().unwrap());
let b_trailer = u64::from_le_bytes(b[b.len() - 8..].try_into().unwrap());
// revsersed order
b_trailer.cmp(&a_trailer)
}
}
}

pub const BYTEWISE_COMPARATOR: Comparator = Comparator { cmp: cmp::Ord::cmp };

一个是默认的字节序比较器,一个是根据seq做降序比较的比较器,最后的一个字节是type,但是
seq递增的,所以一起比较也不会影响大小,写成静态的结构体来作为相关函数的静态工具集合,
后面还会扩展对应的方法,所以定义了两个Comparator

至于说为什么不定义成InternalKey的方法是因为,我们的比较
大部分情况下是对其他字节序列做比较,因为InternalKey有所有权,
转换成InternalKey要进行一次数据拷贝。

我们可以类似这样定义一个持有引用的对象然后给这种对象增加方法,
因为区别不是很大就还是按照上面的方式实现,没有那么面向对象。

InternalComparator是包含UserComparator的这样就可以支持用户自定义的排序排序方式。

1
2
3
4
5
type InternalKeyRef = &[u8];

impl InternalKeyRef {
cmp(self, other: Other)...
}

SSTable

SSTable(Sorted String Table)是用于各种键值数据库(如LevelDB、RocksDB和Cassandra)的基本数据结构。它旨在高效地存储和检索键值对,同时保持不变性、键有序和基于磁盘存储的原则。
SSTable是一个不可变对象,这和Rust的默认不可变属性很类似。

SSTable的内容通常由以下部分组成:

  • 排序键值数据: SSTable主要包含按特定顺序排序的键值对,以实现键的范围查询和键的单一检索。其中,值和键通常为字符串或字节序列。
  • 数据块: 数据块是SSTable中存储一系列键值对的部分,专门设计用于高效读取操作。每个块可以包含固定数量的键值对,并且通常会进行压缩以节省磁盘空间。
  • 块缓存:块缓存是一个组件,旨在通过在内存中存储频繁访问的SSTable块来提高读取性能。
  • 元数据块: SSTable通常还包括有关其自身基本信息的元数据,在文件版本、结构等方面提供帮助以正确管理和读取SSTable属性等详细信息。
  • 索引块: 索引模块在快速数据检索方面起着关键作用,其中包含元数据和与之相关联的键引用信息,能够快速获取从特定文件偏移量到相应数据块之间映射关系,并支持高效随机访问特定键值对,索引块是一种元数据块。
  • 布隆过滤器: 某些情况下,SSTable可能还会附带Bloom过滤器作为概率型数据结构来判断是否存在特定键,在查找操作期间允许系统跳过不必要I/O请求并减少磁盘读取次数。
  • 压缩:SSTable中的压缩涉及到应用算法来减少数据块的大小。常用的压缩算法有Snappy、LZ4和zlib。
  • 墓碑: 在支持删除功能的数据库中,SSTable可能会包括墓碑来标记已删除的密钥,并确保这些被删除密钥仍然被考虑进去以保持一致性。
  • 文件格式: SSTable内容采用特定文件格式进行组织和编码,在不同数据库系统之间可能存在差异;该设计旨在实现高效读写操作及合并重叠SSTable、有效管理磁盘空间等压缩过程。
  • 校验和: 一些SSTable可能包含校验和或哈希值,在读取操作期间验证数据完整性;校验和有助于确保存储或传输过程中没有损坏。

RocksDB中的SSTable的格式如下:

LevelDB是每个block有一个专有的filter,但考虑到我们的sst文件的filter并不大,所以参考RocksDB的默认实现只用一个full filter。

每个SSTable的block的最大的大小是典型的和page大小一致的4K大小,在一般的数据库或者存储系统中都会使用这个大小,因为
这个大小是page cache的大小,和page cache大小对齐(当然也可以倍数)能够充分利用操作系统的 page cache。

目前我们先实现一个只有一个block的SSTable,规定SSTable的第一层只能是4K,先实现block内部的数据结构。后面我们可以扩展
多个block的SSTable。

我们希望Table有搜索的能力(使用于查找某个键的时候),也有遍历的能力(使用于归并的时候)。

搜索基于二分查找,读取每个restart的开头key作为比较的key,二分查找对应的restart(该restart是key所在的区间)然后遍历restart查找该key。

Data Block构建器

BlockBuilder对key的存储是前缀压缩的,对于有序的字符串来讲,这能极大的减少存储空间。但是却增加了查找的时间复杂度,为了兼顾查找效率,每隔K个key,leveldb就不使用前缀压缩,而是存储整个key,这就是重启点(restartpoint)。重启点不依赖之前的前缀读取完整的key,可以作为二分查找的分界点,合适的间隔是对压缩效率和查询效率的平衡。

1
2
3
4
5
6
<beginning_of_file>
[data block 1]
[meta block 1: filter block]
[metaindex block]
[Footer]
<end_of_file>

其中一个block由多个restart构成。

1
2
3
restart 1
restart 2
restarts_length

每个restart由一个header和一对kv构成,plen表示和之前的key的共享长度的部分可以省略存储的空间,klen表示不共享的长度部分,
value表示值的长度。

1
2
3
4
5
plen
klen
vlen
key
value

作为可选项,可以对block进行压缩,所以block后会追加一个字节表示压缩算法再追加一个crc32的校验值,总共要多5个字节。
如果设置了这些可选项,那block的最终格式是:block data | type(1B) | crc32(4B)

压缩

Snappy

Snappy 是由 Google 开发的一种快速数据压缩和解压缩算法。
它专注于提供较高的压缩速度和相对较快的解压速度,适用于需要在低延迟环境中传输大量数据的应用场景。
Snappy 不是通用的无损压缩算法,因此它可能不适用于所有类型的数据。

Snappy 的压缩算法基于一系列的变长编码和字典压缩。
它使用两种主要类型的标记:字面值标记和复制标记。
字面值标记用于表示原始数据的一部分,而复制标记用于表示先前出现过的数据块的重复。
这使得 Snappy 在处理一些特定模式的数据时能够取得很好的压缩效果。

Data Block读取

  1. Block 结构

    • Block 结构包含两个字段,content 是块的二进制内容,restarts 是存储重启点偏移的数组。
  2. Block::new 方法

    • 通过 Read 实例读取整个块的内容。
    • 解析末尾的重启点数量,获取重启点偏移数组。
    • 截断块的内容,去除重启点信息,返回 Block 实例。
  3. Block::restart_iter 方法

    • 根据给定的重启点索引,初始化 RestartIterator 实例,用于迭代重启点范围内的条目。
    • 计算重启点的起始和结束位置。
  4. Block::iter 方法

    • 返回 BlockIterator 实例,用于迭代整个块的内容。
    • 将块的内容、重启点数组、以及迭代状态传递给 BlockIterator
  5. RestartIterator 迭代过程

    • 迭代器按顺序遍历重启点范围内的条目。
    • 每个条目由键和值组成,通过解析块内容的头部信息获取长度信息,依次读取键和值。
  6. BlockIterator 迭代过程

    • 迭代器按顺序遍历整个块的内容。
    • 对于每个条目,解析块内容的头部信息,获取键和值的长度,依次读取键和值。
  7. Block::get_ge 方法

    • 使用二分查找在块的重启点中找到第一个大于等于给定键的位置。
    • 通过迭代器查找该位置对应的键值对,返回找到的值。
  8. Block::get 方法

    • 使用二分查找在块的重启点中找到等于给定键的位置。
    • 通过迭代器查找该位置对应的键值对,返回找到的值。

总体来说,块的读取过程通过迭代器实现对块内容的顺序访问,而二分查找用于在重启点中快速定位目标位置,提高查找效率。

统一Filter Block构建

分块的Filter Block构建

Full Filter适合SSTable的key比较少的情况,每个SSTable用一个full filter即可,如果SSTable的key较多的情况下,full filter一次性要load太多
不太合适,所以可以按照每个block的规模去创建block级别的fitler。
LevelDB的设计是用base划分filter block,只要data block落在一个base的区间内那么data block的filter负责人就是对应base的filter。
base一般是11也就是2048,举例来说,如果两个block的开始offset都小于2048那么都由第0个filter来管理。
block_offset / base = filter index. filter的结构如下。一个filter可以管理多个block,一个block很大的话会有空的filter,只要这个filter在block范围内。
空的filter的offsets是和之前的offset一样。这里的空的filter和EmptyFilterPolicy又是两回事了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
[filter 0]
[filter 1]
[filter 2]
...
[filter N-1]

[offset of filter 0] : 4 bytes
[offset of filter 1] : 4 bytes
[offset of filter 2] : 4 bytes
...
[offset of filter N-1] : 4 bytes

[offset of beginning of offset array] : 4 bytes
lg(base) : 1 byte

这个设计在RocksDB里面被废弃了而采用了完整Filter Block。

布隆过滤器

LevelDB BloomFilter
Double Hash
Bloom Filter是一种空间效率极高的概率型数据结构,用于判断一个元素是否在一个集合中。它可能会产生误判,即判断一个不存在的元素存在,但不会误判存在的元素。

Bloom Filter的基本原理是,当一个元素被插入集合时,通过K个哈希函数将这个元素哈希成K个位置,然后将这些位置的位都设为1。检查一个元素是否在集合中时,通过同样的哈希函数找到K个位置,如果任何一个位置的位为0,则元素一定不在集合中;如果所有位置的位都为1,则元素可能在集合中。

在LevelDB中,Bloom Filter被用作SSTable的一部分,用于减少不必要的磁盘读取。当查找一个键时,首先使用Bloom Filter判断这个键是否可能在SSTable中,如果可能在,则进行磁盘读取;否则,直接跳过这个SSTable,从而减少了不必要的磁盘读取。

Bloom Filter的优点是空间效率和查询时间都极高,特别适合于元素数量巨大,但内存空间有限的场景。缺点是存在一定的误判率。

Bloom Filter本身是个bit vector,最后一个字节保存K的值。

在LevelDB的Bloom Filter中,使用了一种称为双重哈希(Double Hashing)的技术。这种技术的主要目的是为了解决哈希冲突,即当两个不同的输入产生相同的哈希值时,如何处理。

双重哈希的基本思想是,当哈希冲突发生时,不是简单地在哈希表中寻找下一个空闲位置,而是使用第二个哈希函数来确定探测序列。这个第二个哈希函数会根据输入的键值生成一个新的哈希值,这个新的哈希值会与原始的哈希值结合在一起,用于确定在哈希表中的位置。

在LevelDB的Bloom Filter中,双重哈希的实现方式是,首先使用一个哈希函数将键值哈希到Bloom Filter的某个位置,然后使用第二个哈希函数生成一个新的哈希值,这个新的哈希值用于确定在Bloom Filter中的第二个位置。这样,每个键值在Bloom Filter中都会对应两个位置,大大降低了哈希冲突的可能性,从而提高了Bloom Filter的效率和准确性。

在Bloom Filter中,K值表示我们使用的哈希函数的数量。如果K值大于2,我们将对每个插入的元素应用K个哈希函数,然后在Bloom Filter中设置对应的K个位置。

以下是一个简单的例子,假设我们有一个空的Bloom Filter,长度为m,和3个哈希函数(即K=3)。当我们插入一个元素时,我们将对这个元素应用这3个哈希函数,得到3个哈希值。然后,我们将这3个哈希值对m取模,得到在Bloom Filter中的3个位置,然后将这3个位置的位都设为1。

当我们要检查一个元素是否在集合中时,我们也会对这个元素应用这3个哈希函数,得到3个哈希值,然后查看Bloom Filter中对应的3个位置。如果这3个位置的位都为1,那么我们认为这个元素可能在集合中;如果这3个位置中有任何一个位为0,那么这个元素肯定不在集合中。

需要注意的是,随着K值的增大,误判率会降低,但是插入和查询的时间复杂度也会增大。因此,选择合适的K值是一个需要权衡的问题。

如果你有很多点查找操作(例如Get()),那么Bloom Filter可以帮助加速这些操作,相反,如果你的大部分操作是范围扫描(例如Scan()),那么Bloom Filter就没有帮助了。因为Range两端的两个key可能不存在,这样过滤器就不会生效了,大部分情况下都退化成了一次搜索大于左边界的键和从当前位置遍历直到超过右边界。

murmurhash有比较好的平衡分布的特性,计算速度快也比较简单,LevelDB使用的是最简单的murmur1。他之所以叫murmur的原因是因为他的算法就是两次乘法(multiply)和右移操作(Right shift),在版本三种这个R其实是(Rotate Left)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
fn murmur1(bytes: &[u8], seed: u32) -> u32 {
let m: u32 = 0xc6a4a793;
let r = 24;
let mut h: u32 = 0;
// seed ^ ( n * m )
// may overflow
let mut h: u32 = seed ^ (bytes.len() as u32).wrapping_mul(m);
for chunk in bytes.chunks(4) {
if chunk.len() == 4 {
let w = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
// may overflow
h = h.wrapping_add(w);
h = h.wrapping_mul(m);
h ^= h >> 16;
} else {
for (i, &b) in chunk.iter().enumerate() {
h += u32::from(b) << (i * 8);
}
// may overflow
h = h.wrapping_mul(m);
h ^= h >> r;
}
}
h
}

LevelDB使用Double Hashing模拟多个哈希函数。第一个函数是一个类似murmur的hash函数,而第二个函数则是一个将后15bit和前17bit兑换的简单函数。数据结构包含一个k用来标记hash函数的个数,bits用来保存bit数组。

1
2
3
4
pub struct BloomFilter {
bits: Vec<u8>,
k: usize,
}

Index Block

Index Block 的 key,按照RocksDB的Wiki,
Index Block 的 key >= 当前block,小于下一个block。

经历了几次优化比如把 first encode 到value中,这次就不做这个实现。

理论上用 last key 就可以,但是为了减小大小可以优化一下选择一个较小的key。

比如 [0,1,2] 和 [0,2,3] 当中 [0,1,2]和[0,2]都满足要求。

shortestseperator

shortestseperator是比较器的一部分,在这里作为index block的切分作用,所以在这里做介绍。
shortest_separator的功能:找到一个最字节串介于两个key或者边界之间。
首先找到公共前缀,如果前缀相等那么start就是分割者,如果不想等就顺序找到第一个可以+1的字节(小于255),然后
从那个字节截断就是最小的分割者。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
fn shortest_separator(start: &mut Vec<u8>, limit: &[u8]) {
// Iterate over common prefix of start and limit
let min_length = min(start.len(), limit.len());
let mut diff_index = 0;

while diff_index < min_length && start[diff_index] == limit[diff_index] {
diff_index += 1;
}

// Find the first differing byte
if diff_index < min_length {
let diff_byte = start[diff_index];
if diff_byte < 255 && diff_byte + 1 < limit[diff_index] {
// Increment the differing byte
start[diff_index] += 1;
// Remove the rest of the vector to make it shorter
start.resize(diff_index + 1, 0);
}
} // Do not shorten if one string is a prefix of the other
}

如果是内部key的话,如果分割者比sart短的话,就要完善这个内部ke的构成,我们需要补上一个
MAX_SEQUENCE_NUMBER << 8 | ValueType::TypeForSeek as u64,这样就可以保证分割者比start大。
这里的TypeForSeek表示这个key并没有代表使用的key,而是在index block中起索引作用的key。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
fn internal_shortest_separator(start: &mut Vec<u8>, limit: &[u8]) {
assert!(start.len() >= 8);
assert!(limit.len() >= 8);
let l = &start[0..start.len() - 8];
let j = &limit[0..limit.len() - 8];
let mut tmp = l.to_vec();
(UserKeyComparator.separator)(&mut tmp, j);
if tmp.len() < l.len() {
let pack = MAX_SEQUENCE_NUMBER << 8 | ValueType::TypeForSeek as u64;
tmp.extend_from_slice(pack.to_le_bytes().as_ref());
assert!(internal_cmp(start, &tmp) == Ordering::Less);
assert!(internal_cmp(&tmp, limit) == Ordering::Less);
mem::swap(start, &mut tmp);
}
}

TableBuilder

SSTable是不可修改的,只存在创建和删除。

Option

用于管理选项,例如restart_block_interval,block_size等。

TableReader

block的设计有一个好处可以根据block index去读取block。

Block Cache

LRU Cache 需要一个队列和HashMap并且,为了缓解缩的压力也可以对key分片做分段锁。
当需要从BlockHandle获取Block时会先从Block Cache当中获取。

Index 和 Filter Cache

会把level0的index block 和 filer block 缓存起来。
pin_l0_filter_and_index_blocks_in_cache

块缓存用来来缓存未压缩的数据块,它的大小一般来说可以是总内存预算的1/3左右。

LRU Cache

为了加速读取可以将最近使用的一些Table缓存到内存中加速Table的加载。
LRU Cache是个比较经典的数据结构了,但是标准库的双向链表是O(n)的删除。
如果要实现O(1)的链表可以还得自己定一个链表,实现起来比较复杂,我们先暂时容忍一下这个O(n)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
pub struct LRUCache<K, V> {
map: HashMap<K, V>,
order: VecDeque<K>,
capacity: usize,
}

impl<K: Clone + Eq + Hash, V> LRUCache<K, V> {
pub fn new(capacity: usize) -> Self {
LRUCache {
map: HashMap::with_capacity(capacity),
order: VecDeque::with_capacity(capacity),
capacity,
}
}

pub fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &K, f: F) -> &V {
if self.map.contains_key(key) {
self.refresh(key);
self.map.get(key).unwrap()
} else {
self.put(key.clone(), f());
self.map.get(key).unwrap()
}
}

pub fn put(&mut self, key: K, value: V) {
if self.map.contains_key(&key) {
self.refresh(&key);
} else {
if self.map.len() == self.capacity {
if let Some(oldest) = self.order.pop_back() {
self.map.remove(&oldest);
}
}
self.order.push_front(key.clone());
}
self.map.insert(key, value);
}

fn refresh(&mut self, key: &K) {
if let Some(position) = self.order.iter().position(|k| k == key) {
let key = self.order.remove(position).unwrap();
self.order.push_front(key);
}
}
}

Block Cache

TODO

Table Cache

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
pub struct TableCache {
cache: LRUCache<u64, Table<File>>,
opt: Options,
}

impl TableCache {
fn new(opt: Options) -> Self {
Self {
cache: LRUCache::new(100),
opt,
}
}
fn get_table(&mut self, file_num: u64) -> &Table<File> {
self.cache.get_or_insert_with(&file_num, ||{
let file = std::fs::OpenOptions::new()
.read(true)
.open(format!("{}.sst", file_num))
.unwrap();
Table::from_reader(file, self.opt).unwrap()
})
}
}

MANIFEST

MANIFEST也是一个基于record的日志。
版本跟着 sst 文件走,每次插入都会有新版本。
Leveldb每次新生成sstable文件,或者删除sstable文件,都会从一个版本升级成另外一个版本。

相关文件的管理

lognum 用来创建 log 文件。
filenum 用来创建 sst 文件(加载文件时获取)。
level 和文件的映射关系。

当 manifest 文件超过一定大小后会进行压缩,按照全部新建的 edit 进行追加写。
manifest edit 时如果超过大小就会进行压缩。

不需要 version,维护一个 table 的引用计数(rc),当 rc 为 0 时就删除。

需要将 manifest 转换成符合重放的追加写的 edit,这样在重放时可以恢复成原本的 manifest。

合并

合并的基本思想是:数据按多个level组织,目标大小呈指数级增长,例如level0的SSTable的大小是4MB的话,
level1就是8MB,level2就是16MB,以此类推,这里的目标大小相当于level的一个总大小限制。
当一个level的大小超过它的目标大小时,我们选择它的一个或多个文件,并将该文件合并到下一个级别。
合并算法保证了了LSM结构的平衡条件被满足并且可以缩小磁盘占用。合并过程包含的三个主要过程:

minor compaction 和 major compaction

  • 寻找level,选择自身大小除以目标大小的比率最大的level。
  • 在合适的level中挑选合适的文件,比如和下级文件重叠最多的文件,节省空间。
  • 将文件与下一个level的文件(可能多个)进行归并排序组合成新的文件(可能多个),同时之前的墓碑值也可以在合并的时候顺带删除掉。

Major Compaction主要有三种分别为:

  1. Manual Compaction,是人工触发的Compaction,由外部接口调用产生。
  2. Size Compaction,是根据每个level的总文件大小来触发,注意Size Compation的优先级高于Seek Compaction。
  3. Seek Compaction,每个文件的 seek miss 次数都有一个阈值,如果超过了这个阈值,那么认为这个文件需要Compact。

compact_pointer 就是 round robin的标记,如果没找到就从第一个文件开始。LevelDB挑选文件的方式很简单。

读写存储放大

读写存储放大(Read/Write/Space Amplification)是数据库系统中一个重要的性能问题,特别是对于基于磁盘的存储引擎。
读写存储放大是我们在讨论合并策略的时候需要参考的重要指标。
例如:当我们为了减少存储放大把冗余的文件合并的时候就会因为重写多个冗余文件而引入写放大。当不重写文件来减少写放大的时候,又会引入
存储放大和读放大(对应的key可能需要检查多个冗余文件才能找到)。

我们衡量读写存放大的指标是这么计算的:

  • 读放大:读取过程中读取的全部数据/实际读取本身所需要的数据。
  • 写放大:写入过程中写入的全部数据/实际写入本身所需要的数据。
  • 存放大:存储数据占用的磁盘空间/实际存储的数据本身占用的磁盘空间。

数据库的大量优化除了操作执行的时间之外,很重要的性能指标就是这些“放大”,合并当中的很多优化都是围绕它们展开的。

Version

Version中的level的映射可以使用Vec,files是有序的也只需要一个Vec,然后根据file的Smallest和Largest的Key做二分搜索既可以。

Version一个引用计数器,读的时候Ref,读结束的时候UnRef,当没有引用的时候从VersionSet当中删除。
写入的时候是先写到MemTable再写到WAL,和Version没关系,但是Compact的时候有关系。

Verion和Compaction的关系是一对一的。
VersionSet包含多个Version,每个Version是串联的关系,一个Compaction会导致一个新的Version的追加。

LevelDB在合并和打开数据库的时候会删除不需要的文件,这个不需要是通过live_files获取的。
live_files来自所有的live的version,不live的Version会从version_set内删除。

合并涉及到删除操作,但是可能当前有相关文件的读操作还没有结束,所以我们希望在没有引用的情况下再删除改文件。

Version相较于MANIFEST是内存中的对象,每次重启以后只会初始化一个只有一个Version的VersionSet。
每次元数据的改动(SSTable的增加删除和移动等)更新,会导致一次MANIFEST文件的追加写,和内存中的一个新的Version的产生。

1
2
/// VersionSet managed the various versions that are live within a database. A single version
/// contains references to the files on disk as they were at a certain point.

版本保存了当前LSM结构的元信息,版本信息对应的是每个level有哪些文件存在,删除的文件不会立即删除
有可能有访问者正在访问。Version对应的版本是manifest的版本,每个版本对应了不同时刻的levels和sst的变化。

如果有读事务位于旧的文件,那么暂时就不能删除。因此利用引用计数,只要一个Verison还活着,就不允许删除该Verison管理的所有文件。当一个Version生命周期结束,它管理的所有文件的引用计数减1。

当sst文件不再被“活着”的版本引用的时候就可以删除对应的文件了。

有几种情况

  • 重启的时候对log的重放
  • minor compaction
  • major compaction 没有文件合并
  • major compaction 有文件合并

VersionEdit会以追加写的Record形式追加到manifest当中。

这样重放的时候不需要全部堆到level0。
Riak 1.3 版本做了优化,改变了目录结构,对于google 最初版本的LevelDB,所有的文件都在一个目录下,但是Riak 1.3版本引入了子目录, 将不同level的sst 文件放入不同的子目录:

sst_0
sst_1

sst_6

MANIFEST文件和LOG文件一样,只要DB不关闭,这个文件一直在增长。我查看了我一个线上环境,MANIFEST文件已经膨胀到了205MB。

试试上,随着时间的流逝,早期的版本是没有意义的,我们没必要还原所有的版本的情况,我们只需要还原还活着的版本的信息。MANIFEST只有一个机会变小,抛弃早期过时的VersionEdit,给当前的VersionSet来个快照,然后从新的起点开始累加VerisonEdit。这个机会就是重新开启DB。

LevelDB的早期,只要Open DB必然会重新生成MANIFEST,哪怕MANIFEST文件大小比较小,这会给打开DB带来较大的延迟。
MANIFEST 文件列出了构成每个级别的排序表集、相应的键范围以及其他重要元数据。每当数据库重新打开时,都会创建一个新的 MANIFEST 文件(文件名中嵌入一个新编号)。MANIFEST 文件的格式为日志,对服务状态所做的更改(如文件的添加或删除)都会附加到该日志中。

在系统(重新)启动时,最新的清单日志(manifest log)包含了数据库的一致状态。任何后续变化都会被记录到清单日志文件中。
当清单日志文件超过一定大小时,就会创建一个包含状态快照的新清单日志文件。最新的清单文件指针会被更新,文件系统也会同步。一旦成功更新到 CURRENT 文件,多余的清单日志就会被清除。
Badger

Version和FileMetadata是多对1的,当没有Version引用这个FileMetadata的时候就可以删除对应的文件。

也可以通过比较当前Version和数据库目录下的文件删除不需要的,这个在数据库重启的时候比较有用。

VersionEdit

VersionEdit是一次修改的记录,apply一次VersionEdit以后会生成一个新的Version。
VersionEdit使用的是WAL一样的Record。

1
2
3
4
+-------------+------ ......... ----------+
| Record ID | Variable size record data |
+-------------+------ .......... ---------+
<-- byte --->|<-- varies by type -->

例如删除文件 RecordID 为 DeleteFile,保存了删除了哪个level的文件号。

1
2
3
4
5
6
Mark a file as deleted from database.

+-----------------+-------------+--------------+
| kDeletedFile | level | file number |
+-----------------+-------------+--------------+
<-- byte --->|<-- Var32 -->|<-- Var64 -->|
1
2
// The MANIFEST file describes the startup state of the db -- all LSM files and what level they're
// at.

manifest记录了两个主要对象 levels 和 table。
他是一条一条的changeSet,需要replay一次才能重新构建出来。
有哪些level,这些level都有哪些SSTable。

changeSet 在rocksdb里面叫 versionedit

Manifest -> version
Version set
VersionSet
VersionSet 是一个 Version 的集合。
随着数据库状态的变化,LevelDB 内部会不停地生成 VersionEdit——进而产生新的 Version。此时,旧的 Version 可能还在被正在执行的请求使用。所以,同一时刻可能存在多个 Version。
VersionSet 用一个链表将这些 Version 维护起来,每生成一个 Version 就往这个链表尾部插入一个节点(AppendVersion)。

Levels 表示每个level包含的table信息,是一个数组,每个数组的元素是一个 table id的set。
Tables 是一个 map key是id,value包含level 和 checksum的信息。

这样根据 level 可以查找所包含的table,根据table id 可以查找所包含的 level 和 checksum。

level是从零开始连续的所以是个数组。

VersionEdit还有两个指标allowed_seeks和compaction_ptrs,一个代表的是seek操作的次数如果过多某种程度上代表了
文件被搜索的次数过多需要进行合并(TODO:补充其中的逻辑),还有一个代表的是轮转合并的文件指针,只要是为了重启的时候能从
上次的文件之后选择合并的文件,都是和合并相关的指标,在后面和合并相关的章节会展开说明。

VersionBuilder

Version是一个静态只读的数据结构,当我们需要从一个Version转换成另外一个Version的时候,
需要复制当前的Version然后删除和增加对应的文件,我们的files是用Vec保存的,没办法根据file num来做增删。
files的需要保存file num到file metadata的映射需要map,
并且这个file的顺序是有序的要按照边界的key来排序,所以要有一个复合的数据结构。
我们引入一个新的数据结构Version Builder这个Builder会将Version的file拷贝成map然后
应用VersionEdit之后转换再根据file的SmallestKey排序组装回新的Version。

除此之外这个

deleted_file上是一个file num的Set
added_file是一个file metadta的Set

VersionSet

VersionSet是一个双向循环链表,但是在Rust中没有比较好的办法不用unsafe实现
如果你看过 Rust实现链表或者too many list可能会觉得使用Rc加Weak可以,但是双向循环不行,
这个数据结构存在自己指向自己的情况 head.next=head,所以不得不用指针了。

其实不需要这个循环链表,交给Rust的ref去管理就行了。

FileMetadata不是和file一一映射的不能靠file metadata来回收。

链表还是有用的,得查找所有的存活的version。

合并时选择Level

选择文件是一个可以调优的选择,LevelDB用一个单线程的后台进程进行非常简单的RoundRobin轮询挑选。 Compaction操作在key空间中循环执行,详细讲一点就是,对于每个level,我们记录上次compaction的ending key。Level的下一次compaction将选择ending key之后的第一个文件(如果这样的文件不存在,将会跳到key空间的开始)。
也可以使用多线程的进行多文件的合并,每个文件会在一个线程中进行合并,如果有合并参与者则选择下一个文件,
这样就没有读写的冲突,选择的每个并行合并涉及的文件和目标文件互相不干扰,并行合并比较好实现。

这里我们考虑挑选文件的情况:

如果一个文件和下一级的N个文件有交集那么写放到就是近似于N倍,所有的N个文件都会被重写。
如果这个文件的key的分布很松散导致N的数量很大就会导致写放大的问题。
如果key是分布比较均匀这个写放大的效果不会太明显,但实际情况不会如此。
所以RocksDB的一种选项是挑选包含最老更新的文件,这类文件会有最密集的键的分布。或者从另一个角度理解是这个文件在这个level存在的时间够长,相同大小的情况下会保存更多的键,密度也会更大了。
kOldestSmallestSeqFirst 这里的Oldest代表最老,Smallest代表的是时间戳表示最老的更新。

另一种情况是热点键,选择最新更新是最老的文件,则代表的是最“冷”的文件,这样可以减少热点键向下一级移动。
如果一些场景更新比插入的操作更多,具有热点键范围的话可以使用这个选项。
kOldestLargestSeqFirst 这里的Oldest代表最老,Largest代表的是时间戳表示最新的更新。

如果一个文件包含很多墓碑值,它可能会减慢迭代该区域的速度,因为我们仍然需要在合并的时候迭代那些墓碑键。此外,我们越早将墓碑键压缩到最后一层,就能越早回收磁盘空间,因此这有利于提高空间效率。

我们的默认压缩优先级kByCompensatedSize考虑了这种情况。如果文件中的删除(插入墓碑)次数超过插入次数,则更有可能选择该文件进行压缩。删除次数超过插入次数越多,它就越有可能被压缩。这个选项一般是为了解决数据库中有大量的键被删除导致的空间浪费和读放大。

BadgerDB选择的不太一样,是选择最少的overlap,然后这样重写的量最少写放大就会比较少。

我们的实现使用kOldestSmallestSeqFirst这个参数,按照RocksDB的设计。

RocksDB中关于合并的配置

1
2
3
4
cf_options.level_compaction_dynamic_level_bytes = true;
opts.max_background_jobs = 6;
options.bytes_per_sync = 1048576;
options.compaction_pri = kMinOverlappingRatio;

合并是在后台进行的,level0一般都比较小,允许键有重叠,保存在MemTable中。
当MemTable到达一定大小之后,会被转化成SSTable格式刷入磁盘持久化存储。

合并的参与者是两个level,挑选两个level的算法是一个可选项。
选择需要合并的两个level在LevelDB中是level相对越“满”越应该被选为要合并的对象,
也就是如果一个level的大小相对于目标大小的比例total_size / target_size 最大那么就应该被选为合并对象。我们的理想情况应该是两个level合并之后文件的总大小减少最多,或者是让读放大减少。

选定上下两层level以后寻找overlap的部分进行归并排序。
badger 是用 top[0] 和 bottom[:]进行比较
如果是level0的话就是所有table,不是level0的话top就一个文件。

合并完成之后会删除不需要的文件。

动态的Level尺寸

dynamic level size
如果用户数据20GB,那么L6的100GB可能使用不到那么大,这样就会导致写放大。
如果能动态调整level的大小,在初期把l0的直接写入L6。

例如,假设max_bytes_for_level_multiplier=10num_levels=6,以及max_bytes_for_level_base=10MB
level 1到5的目标尺寸开始为:

1
[- - - - 10MB]

因为level 1到4的目标尺寸不适用,所以它们将不会被使用。
直到Level 5的大小增长到超过10MB,例如11MB,我们将基础目标设置为level 4,现在目标看起来是这样的:

1
[- - - 1.1MB 11MB]

随着数据的累积,尺寸目标根据level 5的实际数据进行调整。当level 5的数据达到50MB时,目标是这样的:

1
[- - - 5MB 50MB]

直到level 5的实际大小超过100MB,例如101MB。如果我们继续保持level 4作为基础级别,
它的目标尺寸需要是10.1MB,这不符合目标尺寸范围。
所以现在我们将level 3作为目标尺寸,各级别的目标尺寸看起来是这样的:

1
[- - 1.01MB 10.1MB 101MB]

同样,当level 5进一步增长时,所有级别的目标也增长,如下所示:

1
[- - 5MB 50MB 500MB]

直到level 5超过1000MB并变为1001MB,我们将level 2作为基础级别,
并将级别的目标尺寸设置为:

1
[- 1.001MB 10.01MB 100.1MB 1001MB]

依此类推…

具体可以参考leveled compaction

选择合并文件的边界

当选择文件以后我们需要划分出一个“干净的”边界(亦称clean cutatomic compaction unit)。
我们始终要保持一个版本的恒定性质:level i 的同一个key的版本要高于level i+1 的版本。
下面的描述文件用b来代表,下标作为文件的标号,其中u记作上限,l记作下限,例如,u1代表b1的上限,l2代表b2的下限。

提取来自 |compaction_files| 的最大文件 b1,然后在 |level_files| 中搜索一个文件 b2, user_key(u1) = user_key(l2)。如果找到这样的文件 b2(称为边界文件),则将其添加到 |compaction_files| 中,然后使用这个新的上限再次进行搜索。

之所以这样做的原因是因为:如果存在两个块,b1=(l1, u1) 和 b2=(l2, u2),并且 user_key(u1) = user_key(l2),如果我们压缩了 b1 但没有压缩 b2,那么后续的获取操作将产生不正确的结果(我们保持的恒定的性质:同一个user key,i层的seqnum要高于i+1层的seqnum,并且seqnum降序排序的),因为它将在第 i 层返回来自 b2 的记录而不是来自 b1,也就是b2的seqnum的key把b1中seqnum更大的key给“盖”住了。

在LevelDB当中,文件会”向右“扩展,视图去包含边界上的user key的更低的seqnum(内部key的seqnum是降序排序的)。在其他基于LevelDB的数据库中加入了一些优化。

重叠的分布

但我们讨论写放大的时候就需要考虑重叠的分布,如果一个重叠的分布比较均匀,那么范围越大重叠的文件就会越多。
理想情况下我们会用重叠的大小来代表写放大的程度(这是LevelDB隐含的一个假设),比如一个a-j的范围,分布很均匀,重叠的文件依次是:[a#10,a#1],[b#10,b#1]。
但如果分布不是很均匀 a-j 对应的是 [a#10,e#1],[f#10,f#1],这样写放大的问题就会比较小,而减少了存储放大。

RocksDB针对key的分布不均匀有一些优化,这里我们还是记住一个LevelDB的隐含考量:key的分布是均匀的,重叠越多~文件的写放大越大。

Trivial move

“Trivial move” 是 LevelDB 在执行合并操作时的一个优化策略。如果一个文件在转移到下一层时,并不与下一层的任何文件有重叠,那么这个文件的移动就被认为是 “trivial” 的,即不需要进行复杂的合并操作,可以直接移动文件。这样做的好处是减少了不必要的数据复制,从而提高了整体的性能。另外补充一点:BadgerDB
并没有使用这个策略,因为BadgerDB希望即使是不存在key的合并,但是如果文件里面有删除的key的话,一次复制而不是移动可以减少文件的体积因为合并的过程会把删除的无效key直接删除。

为了利用这个”trivial”移动的优化还需要记录”grandparents”(也就是level i+2)层的重叠文件,LevelDB 通过判断待移动的文件是否与其 “grandparents level” 中的文件有大量重叠来决定是否进行 “trivial move”。如果没有大量重叠,就可以直接移动文件,而不需要合并。这种策略有效地减少了I/O操作和提高了性能,特别是在处理大量数据时。这个大量在LevelDB里面是10*max_file_size(这里的10是每个层级的倍数),默认是25MB。

多路合并

从level到level+1需要对多个文件进行归并排序输出一个新的level+1的文件。我们要定义一个多路的归并器来处理这个问题。

创建一个CompactionState保存TableBuilder和outputfile
保存最小的snapshot,如果没有snapshots就使用最新的sequence。

VersionSet MakeInputIterator

需要的Iterator的数量:

  • level 0 合并成一个 +1 = 2
  • leve > 0 的 size 1= size +1

文件的删除

到目前为止我们的数据库都没有涉及任何的文件删除的能力,所有的老文件都还存在着,当前使用的文件被记录在log当中。
当我们重启数据库的时候我们可以把这些多余的文件删除。

当文件不再被引用的时候既可以删除,或者当Version被释放的时候比较current和目录中的文件集中清楚。
后者在数据库重启的时候比较有用。

FileMetadata

通过重载Drop,可以实现在FileMetadata没有被引用的时候删除这个文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use std::fs;
use std::path::PathBuf;
use std::rc::Rc;
use std::io;

struct TempFile {
path: PathBuf,
}

impl TempFile {
// Creates a new TempFile instance for a given path.
pub fn new(path: PathBuf) -> Rc<Self> {
Rc::new(TempFile { path })
}
}

impl Drop for TempFile {
fn drop(&mut self) {
match fs::remove_file(&self.path) {
Ok(()) => println!("File {:?} was deleted successfully.", self.path),
Err(e) => eprintln!("Error deleting file {:?}: {}", self.path, e),
}
}
}

fn main() -> io::Result<()> {

// Example: Creating a temporary file.
let temp_path = PathBuf::from("my_temp_file.txt");
fs::write(&temp_path, "Temporary file contents")?;

let temp_file = TempFile::new(temp_path.clone());

// Now `temp_file` is an Rc<TempFile>. You can clone `temp_file` to create multiple references.
let temp_file_clone = Rc::clone(&temp_file);

// Both `temp_file` and `temp_file_clone` point to the same `TempFile` instance.
// The TempFile instance will not be dropped (and the file will not be deleted)
// until all Rc references are out of scope.
drop(temp_file);
assert_eq!(fs::metadata("my_temp_file.txt").is_ok(), true);
// Here you could check if the file still exists to verify it was deleted,
// but normally you wouldn't need to do this in your actual application.
drop(temp_file_clone);
assert_eq!(fs::metadata("my_temp_file.txt").is_ok(), false);

Ok(())
}

扩展

Range Delete

在一些将键值数据库作为基础数据库的一些分布式数据库中涉及到删表操作。
Range Delete 会在 sstable 的 meta block 标记。

干净的边界

TODO

事务

TODO

总结

实现LevelDB的过程中,体会到了Rus的很多语法糖和特性,和Go还有Python之间都有相互借鉴,是一个不吝啬引入语言语法复杂性的语言。
在于内存安全方面确实也带了一种新的思考。内容不算完全写完,有的甚至没有实现,有时间在继续补充下去吧。

Prompt缓存可以从两个角度来处理:

基于相似度的外部缓存

一种是对提示词和结果做相似性对比,对结果缓存,这一部分可以在外部来做,例如 langchain的 llm caching。具体实现方法包括:

  1. 向量化存储

    • 将prompt转换为向量表示
    • 使用向量数据库(如FAISS、Milvus等)存储
    • 通过向量相似度检索相近的历史prompt
  2. 模糊匹配

    • 使用编辑距离等算法计算文本相似度
    • 设置相似度阈值进行匹配
    • 返回最相似的历史响应
  3. 缓存策略

    • LRU(最近最少使用)淘汰
    • 时间过期机制
    • 容量限制管理

基于KV Cache的内部优化

另一种是利用 KV Cache 中的交叉注意力机制,复用相同的提示词前缀,这是ChatGPT使用的方法。其工作原理是:

  1. KV Cache机制

    • 存储每个token的Key和Value计算结果
    • 避免重复计算相同前缀
    • 提高推理性能
  2. 增量计算

    • 只对新增的token进行注意力计算
    • 复用已缓存的中间状态
    • 显著减少计算量
  3. 内存管理

    • 自动清理过期缓存
    • 动态调整缓存大小
    • 优化内存使用

在实际应用中,我们可以综合运用这两种缓存方法来优化性能:

  • 对于完全相同或高度相似的prompt,优先使用外部缓存机制
  • 对于部分重叠的prompt,则可以利用KV Cache机制
  • 具体使用哪种策略,需要根据实际场景和资源限制来权衡选择

值得注意的是,KV Cache中的Key和Value都包含了位置编码信息。这意味着要充分发挥prompt缓存的作用,需要确保提示词保持相同的前缀结构。如果提示词的位置发生变化,即使内容相同,对应的KV值也会不同。

具体来说,当两次不同的推理过程中,如果prompt具有相同的提示词前缀,那么这部分的KV计算结果是完全一致的,因此可以直接复用之前推理过程中的KV cache,从而提高推理效率。

最近有一个有趣的论文:通过使用DSL(领域特定语言)来描述prompt结构,可以更精确地控制位置编码。这种方法不仅能够缓存相同的前缀,还支持缓存相同的后缀,同时允许中间部分灵活变动,进一步提升了缓存的效率。但我个人感觉比较难用,等于给本来很灵活的prompt套上了一层结构化的描述语言,这种结构化的语言如果是一些GPT应用的开发有固定模式可能还好,但是通用场景下很难让用户能够用得起来这么专业的描述语言。

像RAG系统中的文档检索结果和固定的记忆上下文,都非常适合作为”提示词前缀”,这样可以更好地利用KV Cache机制。

在设计prompt结构时,我们应该按照内容的变化频率来排序,将越”稳定”的部分放在越前面的位置:

  1. 系统提示词(System Prompt):基本保持不变
  2. 个人记忆(Memory):对特定用户来说相对稳定
  3. RAG检索内容(Context):根据查询动态变化
  4. 对话历史(History):随交互持续更新
  5. 用户输入(User Input):每次都不同

这种由稳定到动态的排序结构可以最大化KV Cache的复用效果:

1
2
3
4
5
System: 你是一个专业的助手。请基于以下上下文回答问题。
Memory: {personal_memory_context}
Context: {retrieved_documents}
History: {chat_history}
User: {user_question}