ggaaooppeenngg

为什么计算机科学是无限的但生命是有限的

Rust从零实现llama

参考Pytorch版本的llama-from-scratch。原文中的RmsNorm的平均值多算了一个维度,这里改成了正确的版本。

首先需要下载TinyShakespeare数据集,这是一个莎士比亚文字数据集。本文档将大致遵循论文的布局,并跳过一些明显的步骤,比如设置虚拟环境和安装依赖项。

我们最终将实现的内容预览:

1
2
3
4
5
6
7
8
9
10
11
12
13
14

println!(generate(llama, MASTER_CONFIG, 500, device)[0])

ZELBETH:
Sey solmenter! 'tis tonguerered if berryishdd, and What his stabe, you, and, but all I pilJefals, mode with,
Vurint as steolated have loven OlD the queen'd refore
Are been, good plmp:

Proforne, wift'es swleen, was no bunderes'd a a quain beath!
Tybell is my gateer stalk smen'd as be matious dazest brink thou
lord
Enves were cIUll, afe and whwas seath This a is, an tale hoice his his onety Meall-tearn not murkawn, fase bettizen'd her,
To belacquesterer? baxewed wupl usweggs yet tall
An

实现过程中可能涉及一些 Rust 的使用方法,与 Python 有所不同,这里不做过多说明,具体的 Rust 语法可以参考其他文档。

迭代工作:从小模块开始,保持确定性,然后逐步构建

  1. 创建所有需要的辅助函数,以便定量测试模型(数据拆分、训练、绘制损失)。
  2. 从论文中挑选出不同的组件,然后逐一实现,边训练边评估。

确保你的层按预期工作

  1. 经常使用 .shape()assert 是你的朋友。
  2. 先在不进行矩阵乘法的情况下计算结果,然后使用 candle 函数使其高效。
  3. 有一个测试来确保你的层是正确的。例如,RoPE 嵌入有一个特定的属性,你可以测试它。对于 Transformer,你可以通过查看注意力矩阵来测试注意力是否正常工作。
  4. 在各种批次、序列和嵌入大小上测试你的层。即使它适用于一种大小,它可能不适用于其他大小,这将在推理时导致问题。

关于 Llama

Llama 是一种基于 Transformer 的语言建模模型。它是一个自回归模型,也称为 CausalModel,模型会将输出中的 token 加入到输入中,不断迭代推理,直到超过上下文长度或遇到停止符。Meta AI 开源了 Llama,并明确表示他们的目标是使模型在推理时更高效,而不是优化训练成本。

接下来,我们将加载库并开始实现。

1
2
3
4
5
6
7
8
9
10
11
12
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::ops::softmax;
use candle_nn::{
embedding, linear, loss, AdamW, Embedding, Init, Linear, Module, Optimizer, ParamsAdamW,
VarBuilder,
};

use core::f32;
use rand::Rng;
use std::collections::HashMap;
use std::fs;
use std::time;

设置数据集

虽然 Llama 在 1.4T 个标记上进行训练,但我们的数据集 TinyShakespeare,即莎士比亚所有作品的集合,大约只有 100 万个字符。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
use std::collections::HashMap;
use std::fs;

fn main() {
// Read the entire content of the file
let lines = fs::read_to_string("./input.txt")
.expect("Failed to read the file");

// Create a sorted set of unique characters
let mut vocab: Vec<char> = lines.chars().collect();
vocab.sort_unstable();
vocab.dedup();

// Create itos and stoi mappings
let itos: HashMap<usize, char> = vocab.iter().enumerate().map(|(i, &ch)| (i, ch)).collect();
let stoi: HashMap<char, usize> = vocab.iter().enumerate().map(|(i, &ch)| (ch, i)).collect();

// Print the first 30 characters of the file
println!("{}", &lines[..30.min(lines.len())]);
}
First Citizen:
Before we proce

他们使用了SentencePiece字节对编码分词器,但我们将只使用一个简单的字符级分词器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use std::collections::HashMap;
use std::fs;
struct Vocab {
itos: HashMap<u32, char>,
stoi: HashMap<char, u32>,
}

impl Vocab {
fn new(itos: HashMap<u32, char>, stoi: HashMap<char, u32>) -> Vocab {
Vocab { itos, stoi }
}
fn decode(&self, ids: &[u32]) -> String {
ids.iter().map(|&id| self.itos[&id]).collect()
}
fn encode(&self, text: &str) -> Vec<u32> {
text.chars().map(|ch| self.stoi[&ch]).collect()
}
fn len(&self) -> usize {
self.itos.len()
}
fn build(lines: &str) -> Self {
// Create a sorted set of unique characters
let mut vocab: Vec<char> = lines.chars().collect();
vocab.sort();
vocab.dedup();

// Create itos and stoi mappings
let itos: HashMap<u32, char> = vocab
.iter()
.enumerate()
.map(|(i, &ch)| (i as u32, ch))
.collect();
let stoi: HashMap<char, u32> = vocab
.iter()
.enumerate()
.map(|(i, &ch)| (ch, i as u32))
.collect();
Self { itos, stoi }
}
}

fn main() {
// Read the entire content of the file
let lines = fs::read_to_string("./input.txt").expect("Failed to read the file");

let vocab = Vocab::build(&lines);
println!("vocab size = {}", vocab.len());
println!("{}", vocab.decode(&vocab.encode("hello")));
}
vocab size = 65
hello

由于数据集较小,我们无需担心内存存储问题。

我们创建了一个 config 对象来存储基本的模型参数。这样可以提高代码的可读性,并且便于修改配置。Rust 是强类型语言,因此所有变量都有明确的类型。

1
2
3
let mut modeConfig = ModelConfig {
vocab_size: vocab.len(),
}
1
2
let dataset = Tensor::from_slice(&vocab.encode(&lines), (lines.len(),), &Device::Cpu).unwrap();
println!("{:?}", dataset.shape());
[1115394]

让我们创建一个方法 get_batches 来生成训练数据和目标的批次。我们将使用相同的方法来生成验证和测试数据,通过 split 参数来控制。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
fn get_batches(
dataset: &Tensor,
split: &str,
batch_size: usize,
context_length: usize,
) -> Result<(Tensor, Tensor)> {
let len_of_dataset = dataset.shape().dim(0).unwrap() as f32;
// 按照 0.8 0.1 0.1 的比例切分训练集, 验证集和测试集
let batch_data = match split {
"val" => &dataset.i((0.8 * len_of_dataset) as usize..(0.9 * len_of_dataset) as usize)?,
"test" => &dataset.i((0.9 * len_of_dataset) as usize..)?,
_ => &dataset.i(..(0.8 * len_of_dataset) as usize)?,
};
// 生成随机index
let mut rng = rand::thread_rng();
let data_len = batch_data.shape().dim(0)?;
let indices: Vec<usize> = (0..batch_size)
.map(|_| rng.gen_range(0..data_len - context_length - 1))
.collect();
let mut x_batches = Vec::with_capacity(batch_size);
let mut y_batches = Vec::with_capacity(batch_size);

for idx in indices {
let x = batch_data.i(idx..(idx + context_length))?;
// y 是 x 后面的一个字符
let y = batch_data.i((idx + 1)..(idx + context_length + 1))?;
x_batches.push(x);
y_batches.push(y);
}
// stack 和 cat 的区别是, stack 是在新的维度上堆叠, cat 是在已有的维度上堆叠
let x_tensor = Tensor::stack(&x_batches, 0)?;
let y_tensor = Tensor::stack(&y_batches, 0)?;
Ok((x_tensor, y_tensor))
}
// in fn main
modeConfig.context_length = 16;
modeConfig.batch_size = 8;
let batch = get_batches(
&dataset,
"train",
modeConfig.batch_size,
modeConfig.context_length,
)?;
println!(
"batch size {}, context_length {}",
batch.0.shape().dim(0)?,
batch.0.shape().dim(1)?
);
for i in 0..modeConfig.batch_size {
println!(
"{:?}, {:?}",
vocab.decode(&batch.0.i(i)?.to_vec1()?),
vocab.decode(&batch.1.i(i)?.to_vec1()?),
);
}

":\nBut, that I'll", "\nBut, that I'll "
"ng?\nWhy, then th", "g?\nWhy, then the"
"s so blind, but ", " so blind, but s"
"thy offices,\nSo ", "hy offices,\nSo r"
"ords, how plainl", "rds, how plainly"
"IET:\nHere's such", "ET:\nHere's such "
"wer\nTo take off ", "er\nTo take off s"
" hurry from the ", "hurry from the f"

实现论文有趣的一点在于,模型“工作”有两个方面:编译(你的张量是否在各层之间匹配)和训练(损失是否下降)。
我们还要定义评估模型的方法。我们希望在定义模型之前就这样做,因为我们希望在训练模型时使用它来评估模型的性能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
fn evaluate_loss(
model: &SimpleBrokenModel,
dataset: &Tensor,
vocab: &Vocab,
config: ModelConfig,
) -> Result<HashMap<String, f32>> {
let mut out = HashMap::new();
for split in ["train", "val"] {
let mut losses = Vec::new();
for _ in 0..10 {
let (xs, ys) = get_batches(&dataset, split, config.batch_size, config.context_length)?;
let (_, loss) = model.forward(&xs, Some(&ys))?;
let loss = loss.unwrap();
losses.push(loss.to_scalar::<f32>()?);
}
let avg_loss = losses.iter().sum::<f32>() / losses.len() as f32;
out.insert(split.to_owned(), avg_loss);
}
Ok(out)
}

设置一个可以工作的简单模型

这是一个带有嵌入的基本前馈神经网络。它是我们将要开始的基础模型,然后我们将逐步替换其部分内容,直到最终得到 Llama 论文中描述的模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
struct SimpleBrokenModel {
embedding: Embedding,
mlp: Sequential,
config: ModelConfig,
}

impl SimpleBrokenModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
// 潜入层
let embeds = self.embedding.forward(x)?;

// 线性和激活层
let logits = self.mlp.forward(&embeds)?;

// 如果提供了targets就计算loss,不然视为推理,计算logits就可以。
if let Some(targets) = targets {
// 负的似然函数
// -log(x) 越大,loss 越小
// y = [0, 0 , 0, 0, 1, ...,0,0]
// y' = [4, 5, 6, 7, 8, ...,11,12 ]
// 这个 cross_entropy 帮我们做了一个 log softmax
// y' = [0.1, 0.12, 0.13, 0.64, ..., 0,0]
// loss = -log(0.64)
// 当 -log(q) = 4.17 q = 0.015 大概 1/64,vocab_size = 65,所以基本是在瞎猜。
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}
// VarBuilder是用来构建参数的,我们目前不加载和保存模型参数,但是candle的用法必须基于这个。
// vb.pp 会在参数树中加入参数的前缀,这样可以方便的查看参数的结构。
fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
})
}
}
// in fn main
modeConfig.d_model = 128;
modeConfig.batch_size = 32;
let (xs, ys) = get_batches(
&dataset,
"train",
modeConfig.batch_size,
modeConfig.context_length,
)?;
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let model = SimpleBrokenModel::load(vb, modeConfig)?;
let (logits, loss) = model.forward(&xs, Some(&ys))?;
println!("{:?} {:?}", logits, loss);
let mut params_count: usize = 0;
for (name, var) in varmap.data().lock().unwrap().iter() {
println!("{}: {:?}", name, var.elem_count());
params_count += var.elem_count();
}
println!("params count: {}", params_count);
Tensor[dims 32, 16, 65; f32] Some(Tensor[5.266067; f32])
model.fc2.weight: 8320
model.embed_tokens.weight: 8320
model.fc1.bias: 128
model.fc2.bias: 65
model.fc1.weight: 16384
params count: 33217

在这一点上,我们必须开始关注张量的形状,并让矩阵的维度匹配。查看我们模型定义中的这一行:

1
2
3
4
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;

我们必须调整 logitstargets 张量的形状,以便在比较时它们的维度匹配。我们使用 reshape 方法来实现这一点。
() 参数的意思是“从其他维度推断这个维度”。所以,在这种情况下,我们是在说“将 logitstargets 重新调整为具有相同行数的形状,并使用所需的列数来实现这一点”。这是处理批量数据时的常见模式。

让我们训练我们的 SimpleBrokenModel 以确保梯度流动。在确认这一点之后,我们可以替换它的部分内容以匹配 Llama,再次训练并跟踪我们的进展。在这一点上,我开始记录我的训练运行日志,这样如果我搞砸了,我可以轻松地回到之前的运行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
fn train(
config: ModelConfig,
model: &SimpleBrokenModel,
opt: &mut AdamW,
dataset: &Tensor,
vocab: &Vocab,
) -> Result<()> {
let mut start_time = std::time::Instant::now();
for epoch in 0..config.epochs {
let (xs, ys) = get_batches(&dataset, "train", config.batch_size, config.context_length)?;
let (_, loss) = model.forward(&xs, Some(&ys))?;
opt.backward_step(&loss.unwrap())?;
if epoch % config.log_interval == 0 {
let batch_duration = start_time.elapsed().as_secs_f32();
let loss = evaluate_loss(&model, dataset, vocab, config)?;
let val_loss = loss.get("val").unwrap();
let eta = batch_duration * (config.epochs - epoch) as f32;
let eta = eta.round();
println!(
"Epoch: {epoch} | Loss: {val_loss} | Time: {batch_duration} | ETA in seconds {eta}"
);
start_time = time::Instant::now();
}
}
Ok(())
}
// in fn main
modeConfig.log_interval = 10;
modeConfig.epochs = 100;
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), ParamsAdamW::default())?;
train(modeConfig, &model, &mut opt, &dataset, &vocab)?;
let out = evaluate_loss(&model, &dataset, &vocab, modeConfig);
println!("{:?}", out);
Epoch: 10 | Loss: 3.9159875 | Time: 6.5813394 | ETA in seconds 599
Epoch: 20 | Loss: 3.26492 | Time: 6.3639965 | ETA in seconds 515
Epoch: 30 | Loss: 2.9944448 | Time: 6.3596206 | ETA in seconds 452
Epoch: 40 | Loss: 2.8793342 | Time: 6.357106 | ETA in seconds 388
Epoch: 50 | Loss: 2.7827232 | Time: 6.3562865 | ETA in seconds 324
Epoch: 60 | Loss: 2.764416 | Time: 6.352279 | ETA in seconds 260
Epoch: 70 | Loss: 2.7196321 | Time: 6.356127 | ETA in seconds 197
Epoch: 80 | Loss: 2.7631993 | Time: 6.357493 | ETA in seconds 134
Epoch: 90 | Loss: 2.696882 | Time: 6.358631 | ETA in seconds 70
Epoch: 100 | Loss: 2.670012 | Time: 6.3603354 | ETA in seconds 6
Ok({"train": 2.591057, "val": 2.6625311})
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
fn generate(model: &SimpleBrokenModel, vocab: &Vocab, max_tokens: usize) -> Result<()> {
// batch size 5, initial token = 0
let mut token_ids = Tensor::zeros((5, 1), DType::U32, &Device::Cpu).unwrap();
for _ in 0..max_tokens {
let (logits, _) = model.forward(&token_ids, None)?;
assert!(logits.shape().dims() == [token_ids.dim(0)?, token_ids.dim(1)?, 65]);
let last_step_logits = logits.i((.., logits.dim(1)? - 1))?;
assert!(last_step_logits.shape().dims() == [token_ids.dim(0)?, 65]);
let probs = softmax(&last_step_logits, last_step_logits.dims().len() - 1)?;
assert!(probs.shape().dims() == [token_ids.dim(0)?, 65]);
let next_token = probs.argmax(probs.dims().len() - 1)?;
assert!(next_token.shape().dims() == [token_ids.dim(0)?]);
token_ids = Tensor::cat(&[token_ids, next_token.reshape(((), 1))?], 1)?;
}
let lines = fs::read_to_string("./input.txt").expect("Failed to read the file");
for v in &token_ids.to_vec2()? {
let text = vocab.decode(v);
println!("{}", text);
}
Ok(())
}
// fn in main
generate(&model, &vocab, 10, device)?;
['\nFind!\nD:\nAr t,\nLis sthte o t l',
 '\nAnd ronnot ar\nBE:\nKINRDYOrspr;',
 '\nI t athe momyengthend thanswal',
 '\nFis t bp he\nLacarn.\nA:\nYOMI wi',
 '\nWh ly sck\nB-de pll t\nHERIns ou']

这还算不错,但也不算太好。不过现在我们有了一个可以训练到验证损失的工作模型。因此,我们将在此基础上迭代我们的模型,使其更接近 Llama。

Llama 具体细节

Llama 对原始 Transformer 进行了三项架构修改:

  1. 用于预归一化的 RMSNorm
  2. 旋转嵌入 RoPE
  3. SwiGLU 激活函数

我们将逐一添加每个修改到我们的基础模型,并进行迭代。

RMSNorm

在 Vaswani 2017 中,原始的 Transformer 使用了 BatchNormalization。在 Llama 中,作者使用了 RMSNorm,这是一种在不进行中心化的情况下通过方差来缩放向量的方法。此外,虽然 Vaswani 将归一化应用于注意力层的输出(后归一化),但 Llama 将其应用于输入之前(前归一化)。

这篇文章对于RMSNorm有一个很好的解释。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
pub struct RmsNorm {
scale: Tensor,
eps: f64,
}
impl RmsNorm {
fn new(size: usize, vb: VarBuilder) -> Result<Self> {
Ok(RmsNorm {
scale: vb.get_with_hints(size, "weight", Init::Const(1.))?,
eps: 1e-6,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_sqr = x.sqr()?;
assert!(x_sqr.shape().dims() == x.shape().dims());
let norm_x = (x.mean(D::Minus1)? + self.eps)?.sqrt()?;
assert!(norm_x.shape().dims() == [x.shape().dim(0)?, x.shape().dim(1)?]);
let x_normed = x.broadcast_div(&norm_x.reshape((
norm_x.shape().dim(0)?,
norm_x.shape().dim(1)?,
(),
))?)?;
assert!(x_normed.shape().dims() == x.shape().dims());
let x = (x_normed.broadcast_mul(&self.scale))?;
Ok(x)
}
}

let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);

let rms_norms = RmsNorm::new(2, vb)?;
// (2,3,2)
let batch = Tensor::new(
vec![
vec![vec![1f32, 1f32], vec![1.2f32, 2f32], vec![3f32, 3f32]],
vec![vec![4f32, 43f32], vec![5f32, 5f32], vec![61f32, 6f32]],
],
&Device::Cpu,
)?;
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let out = rms_norms.forward(&batch)?;
Tensor[dims 2, 3, 2; f32]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
struct SimpleBrokenModel {
embedding: Embedding,
mlp: Sequential,
rms_norm: RmsNorm,
config: ModelConfig,
}

impl SimpleBrokenModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
// Embedding
let embeds = self.embedding.forward(x)?;
// RMSNorm
let normed_embeds = self.rms_norm.forward(&embeds)?;
// Linear layers and activation
let logits = self.mlp.forward(&normed_embeds)?;

// Calculate loss if targets are provided
if let Some(targets) = targets {
// 负的似然函数
// log(x) 越大,loss 越小
// y = [0, 0 , 0, 0, 1, ...,0,0]
// y' = [4, 5, 6, 7, 8, ...,11,12 ]
// 这个 cross_entropy 帮我们做了一个 log softmax
// y' = [0.1, 0.12, 0.13, 0.64, ..., 0,0]
// loss = -log(0.64)
// -log(q) = 4.17 q = 0.015 大概 1/64,vocab_size = 65,所以基本是在瞎猜。
// println!("{:?}", targets.shape());
// println!("{:?}", logits.shape());
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}
// VarBuilder是用来构建参数的。
fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
rms_norm,
})
}
}

Epoch: 10 | Loss: 4.1559505 | Time: 6.779387 | ETA in seconds 617
Epoch: 20 | Loss: 4.14648 | Time: 6.7727704 | ETA in seconds 549
Epoch: 30 | Loss: 4.1364665 | Time: 6.776428 | ETA in seconds 481
Epoch: 40 | Loss: 4.125594 | Time: 6.772582 | ETA in seconds 413
Epoch: 50 | Loss: 4.120083 | Time: 6.7661977 | ETA in seconds 345
Epoch: 60 | Loss: 4.1099877 | Time: 6.760399 | ETA in seconds 277
Epoch: 70 | Loss: 4.0996284 | Time: 6.7623253 | ETA in seconds 210
Epoch: 80 | Loss: 4.0902996 | Time: 6.761824 | ETA in seconds 142
Epoch: 90 | Loss: 4.0833025 | Time: 6.76845 | ETA in seconds 74
Epoch: 100 | Loss: 4.070025 | Time: 6.7624454 | ETA in seconds 7
Ok({"train": 4.072861, "val": 4.0711236})

从这里得到的结果来看,范化以后,模型的表现并没有提升,所以我们需要继续迭代,只是梯度的下降变得比较平滑了。

Rotary Embeddings

RoPE 是一种用于 Transformer 的位置编码方法。在《Attention is All You Need》中,作者提出了两种位置编码方法:学习的和固定的。在 RoPE 中,作者通过旋转嵌入来表示序列中标记的位置,并在每个位置使用不同的旋转角度。

其中的 cos 和 sin 值可以预先计算并缓存,避免重复计算,后续会统一存放在一个缓存结构中。

RoPE 将 hidden_state 中每两个 x 组成的向量与旋转矩阵相乘来实现位置编码。

1
2
3
4
[x0, x1, .... ,xn]
y0 = x0 * cos(theta) - x1 * sin(theta)
y1 = x0 * sin(theta) + x1 * cos(theta)
[y0, y1, ...., yn]

nd_model 的一半。
theta 是一个根据位置得到的固定值,计算公式为:
theta = m / 10000^(2i/n)
其中,m 是在序列中的位置,i 是在 d_model 中的位置。

这个公式的含义是将特征向量中的 x0x1 进行一个固定的旋转,这个旋转不是通过学习得到的,而是预先计算的。它可以用于表示相对位置信息。

1
2
3
```Rust
pos_index = 0, 中的 x0,x1
pos_index = 1, 中的 x0,x1

隔了一个恒定的调度旋转。

freq_cis缓存住提前算好的cos和sin的值,这部分不用重复计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
struct Cache {
cos: Tensor,
sin: Tensor,
}

impl Cache {
fn new(context_length: usize, n_elem: usize, vb: VarBuilder) -> Result<Cache> {
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), vb.device())?;
let idx_theta = Tensor::arange(0, context_length as u32, vb.device())?
.to_dtype(DType::F32)?
.reshape((context_length, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let freq_cis_real = idx_theta.cos()?;
let freq_cis_imag = idx_theta.sin()?;

let cos = freq_cis_real.reshape((context_length, n_elem / 2, 1))?;
let sin = freq_cis_imag.reshape((context_length, n_elem / 2, 1))?;
Ok(Cache { cos, sin})
}
}

rope计算的时候

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
fn apply_rotary_emb(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
// println!("shape of cache.cos {:?}", cache.cos.shape());
let cos = cache.cos.i(..seq_len)?;
let sin = cache.sin.i(..seq_len)?;
// println!("cos shape {:?}", cos.shape());
let cos = cos.broadcast_as((b_sz, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, n_embd / 2, 1))?;
// println!("broadcast cos shape {:?}", cos.shape());
let x = x.reshape((b_sz, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, n_embd))?;
Ok(rope)
}

Self Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
struct AttentionModel {
embedding: Embedding,
mlp: Sequential,
rms_norm: RmsNorm,
config: ModelConfig,
self_attention: SelfAttention,
cache: Cache,
}

impl AttentionModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
// 嵌入层
let embeds = self.embedding.forward(x)?;
// 范化层
let normed_embeds = self.rms_norm.forward(&embeds)?;
// 自注意力层
let y = self.self_attention.forward(&normed_embeds, &self.cache)?;
// 线性和激活层
let logits = self.mlp.forward(&normed_embeds)?;

if let Some(targets) = targets {
// 负的似然函数
// -log(x) 越大,loss 越小
// y = [0, 0 , 0, 0, 1, ...,0,0]
// y' = [4, 5, 6, 7, 8, ...,11,12 ]
// 这个 cross_entropy 帮我们做了一个 log softmax
// y' = [0.1, 0.12, 0.13, 0.64, ..., 0,0]
// loss = -log(0.64)
// 例如 -log(q) = 4.17 q = 0.015 大概 1/64,vocab_size = 65,所以基本是在瞎猜。
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}

fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
let self_attention = SelfAttention::load(vb.pp("model.self_attention"), config)?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
rms_norm,
self_attention,
cache: Cache::new(config.context_length, config.d_model, vb)?,
})
}
}

提示:了解训练时张量维度与推理时张量维度的区别。

虽然在训练时,你可以期望张量维度与模型参数紧密匹配,例如 batch.shape = (config['batch_size'], config['context_window'], config['d_model']),但在推理时,你可能需要处理单个示例,例如 batch.shape = (1, 1, config['d_model'])。因此,你需要确保在 forward 传递中进行索引时,使用从输入派生的形状,而不一定是模型参数。

MultiHeadRopeAttention

让我们为这个单一的注意力头设置一个多头注意力层,看看训练时会发生什么。

这里实现的是GQA的注意力头。n_kv_head=1时就是MQA,n_kv_head>1n_kv_head<n_head时就是GQA,n_kv_head=n_head时就是原本的MHA。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
struct MultiHeadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
}

impl MultiHeadAttention {
fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
let q_proj = linear(config.d_model, config.d_model, vb.pp("model.q_proj"))?;
let k_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.k_proj"),
)?;
let v_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.v_proj"),
)?;
let o_proj = linear(config.d_model, config.d_model, vb.pp("model.o_proj"))?;
println!(
"MHA config n_head {} n_kv_head {}",
config.n_head, config.n_kv_head
);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_head: config.n_head,
n_kv_head: config.n_kv_head,
head_dim: config.d_model / config.n_head,
})
}
fn apply_rotary_emb(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = cache.cos.i(..seq_len)?;
let sin = cache.sin.i(..seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
Ok(rope)
}
fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;

// 计算 q k v
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;

assert!(n_embd == self.n_head * self.head_dim);

let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
let k = k.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
let v = v.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;

// 对 q 和 k 做位置编码
let q = self.apply_rotary_emb(&q, cache)?;
let k = self.apply_rotary_emb(&k, cache)?;
// 复制成 n_head / n_kv_head 份
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;

// 把 seq_len 和 n_head 交换
// 这转换一下是为了做一个 cat single head 的简单操作
// 相当于 n_head 个的seq_len*seq_len的注意力。
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;

// q*k^T / sqrt(d_k) d_k = d_model
// 这里是 (bs,n_head) 个 (seq_len, seq_len) 的注意力
let attn = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;

// 这里是头内的softmax (seq_len,seq_len)的行总和为1
let attn = softmax(&attn, D::Minus1)?;

// 再乘 * (bs, n_head, seq_len, head_dim) 得到 (bs, n_head)个注意力头对应的加权的v
let y = attn.matmul(&v)?;
// 把 n_head 和 seq_len 交换回来,得到 (bs, seq_len, n_head, head_dim) 然后reshape以后
// 得到 (bs, seq_len, n_head * head_dim) 把头给cat到一起。
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}

fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
let x = x
.unsqueeze(3)?
.expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
.reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
Ok(x)
}
}
}

完整模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
struct AttentionModel {
embedding: Embedding,
mlp: Sequential,
rms_norm: RmsNorm,
config: ModelConfig,
self_attention: MultiHeadAttention,
cache: Cache,
}

impl AttentionModel {
fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {

let embeds = self.embedding.forward(x)?;

let normed_embeds = self.rms_norm.forward(&embeds)?;
let y = self.self_attention.forward(&normed_embeds, &self.cache)?;

let logits = self.mlp.forward(&y)?;

if let Some(targets) = targets {
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}

fn load(vb: VarBuilder, config: ModelConfig) -> Result<(Self)> {
let embedding = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
let self_attention = MultiHeadAttention::load(vb.pp("model.multi_head_attention"), config)?;
let mlp = sequential::seq()
.add(linear(
config.d_model,
config.d_model,
vb.push_prefix("model.fc1"),
)?)
.add(Activation::Relu)
.add(linear(
config.d_model,
config.vocab_size,
vb.push_prefix("model.fc2"),
)?);
Ok(Self {
embedding,
mlp,
config,
rms_norm,
self_attention,
cache: Cache::new(config.context_length, config.d_model/config.n_head, vb)?,
})
}
}

1
generate(&model, &vocab, 10, device)?;
['\n\n\n\n\n\n\n\nI\n\nOOOOOOOOOFOOtOOOOOOO',
 '\nIIIIII IIIIIIIIIIIIIIIIIIIIIII',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\naaame',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n']

所以看起来很糟糕。这里发生了什么?让我们通过查看注意力来开始调试。

目前的注意力是没有masked的,任何位置的字符都在关注任何其他位置的字符。
这有什么不好呢?我们试图仅基于之前的标记来预测下一个标记,但这里我们看到模型正在关注之后的标记。
换句话说,模型在作弊,或者从未来泄露信息。这是一个问题,这就是为什么我们需要使用因果掩码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132

pub struct Cache {
// ...
// mask 也可以 cached
mask: Tensor,
}
impl Cache {
fn new(context_length: usize, n_elem: usize, vb: VarBuilder) -> Result<Cache> {
// ...
// _ 表示类型由编译器推断,
// 默认的collect 是 [_],但是这个大小是不可变的要编译期间决定
// 所以这里还是要提示编译器要 collect 成 vec.
let mask: Vec<_> = (0..context_length)
.flat_map(|i| (0..context_length).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (context_length, context_length), vb.device())?;
Ok(Cache { cos, sin, mask })
}
}

struct MultiHeadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
}

impl MultiHeadAttention {
fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
let q_proj = linear(config.d_model, config.d_model, vb.pp("model.q_proj"))?;
let k_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.k_proj"),
)?;
let v_proj = linear(
config.d_model,
(config.d_model / config.n_head) * config.n_kv_head,
vb.pp("model.v_proj"),
)?;
let o_proj = linear(config.d_model, config.d_model, vb.pp("model.o_proj"))?;
println!(
"MHA config n_head {} n_kv_head {}",
config.n_head, config.n_kv_head
);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_head: config.n_head,
n_kv_head: config.n_kv_head,
head_dim: config.d_model / config.n_head,
})
}
// ...
fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;

// 计算 q k v
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;

assert!(n_embd == self.n_head * self.head_dim);

let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
let k = k.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
let v = v.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;

// 对 q 和 k 做位置编码
let q = self.apply_rotary_emb(&q, cache)?;
let k = self.apply_rotary_emb(&k, cache)?;
// 复制成 n_head / n_kv_head 份
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;

// println!("q.shape {:?}", q.shape());
// 把 seq_len 和 n_head 交换
// 这转换一下是为了做一个 cat single head 的简单操作
// 相当于 n_head 个的seq_len*seq_len的注意力。
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;

// let tmp = q.matmul(&k.t()?)?;
// 这个结果是有负数的,但是注意力层不会有负数。
// 计算结果出了很多 NaN,感觉应该是范化没做好。
// 后面发现是没有用 sqrt 用了开方,导致负数变成了NaN。
// q*k^T / sqrt(d_k) d_k = d_model
// 这里是 (bs,n_head) 个 (seq_len, seq_len) 的注意力
let attn = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
// 在 softmax 之前,把未来的token位置变为负无穷,这样在softmax之后,这些位置的概率就会变为0
let mask = cache
.mask
.i((..seq_len, ..seq_len))?
.unsqueeze(0)?
.unsqueeze(0)?
.broadcast_as(attn.shape())?;
let on_true =
Tensor::new(f32::NEG_INFINITY, attn.device())?.broadcast_as(mask.shape().dims())?;
let attn = mask.where_cond(&on_true, &attn)?;
// 取一个例子
// 这里是头内的softmax (seq_len,seq_len)的每行 (seq_len) 总和为1
let attn = softmax(&attn, D::Minus1)?;

// 再乘 * (bs, n_head, seq_len, head_dim) 得到 (bs, n_head)个注意力头对应的加权的v
let y = attn.matmul(&v)?;
// 把 n_head 和 seq_len 交换回来,得到 (bs, seq_len, n_head, head_dim) 然后reshape以后
// 得到 (bs, seq_len, n_head * head_dim) 把头给cat到一起。
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}

fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
let x = x
.unsqueeze(3)?
.expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
.reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
Ok(x)
}
}
}

现在,我们可以让注意力激活的上三角部分(对应未来的部分)几乎被归零了。让我们看看训练时会发生什么。

SwiGLU

正如论文中所述,“我们用SwiGLU激活函数替换了ReLU非线性函数……我们使用$\frac{2}{3} 4d$的维度,而不是PaLM中的$4d$。” SwiGLU定义为:
$$
\text{SwiGLU}(x) = \text{Swish}_\beta (xW + b) \otimes (xV + c)
$$
其中$\otimes$是逐元素乘积。Swish函数定义为:
$$
\text{Swish}_\beta(x) = x \sigma(\beta x)
$$
其中$\beta$是一个可学习的参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}

struct SwiGLU {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
}
// 新的 mlp 是三层的带gate
impl SwiGLU {
// silu 的特征是允许有一点点的负数
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// 这里就是 SwiGLU SiLU(W_1 * x) * (W_2 * x) 是 element wise 的,这个可以作为gate门信号
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}

fn load(vb: VarBuilder, cfg: ModelConfig) -> Result<Self> {
let h_size = cfg.d_model;
let i_size = cfg.hidden_dim;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
Ok(Self {
c_fc1,
c_fc2,
c_proj,
})
}
}

一个llama block res 两次,一次在attention之前,一次在swiglu之前。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
struct Block {
rms_1: RmsNorm,
attn: MultiHeadAttention,
rms_2: RmsNorm,
swiglu: SwiGLU,
}

impl Block {
fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, cache)? + residual)?;
let residual = &x;
let x = (self.swiglu.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}

fn load(vb: VarBuilder, cfg: ModelConfig) -> Result<Self> {
let attn = MultiHeadAttention::load(vb.pp("self_attn"), cfg)?;
let swiglu = SwiGLU::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::new(cfg.d_model, vb.pp("input_layernorm"))?;
let rms_2 = RmsNorm::new(cfg.d_model, vb.pp("post_attention_layernorm"))?;
Ok(Self {
rms_1,
attn,
rms_2,
swiglu,
})
}
}

现在,让我们通过创建块来添加多个层的最后完整的模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
struct Llama {
embedding: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
cache: Cache,
config: ModelConfig,
}

impl Llama {
pub fn forward(
&self,
x: &Tensor,
targets: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.embedding.forward(x)?;
for block in &self.blocks {
x = block.forward(&x, &self.cache)?;
}
let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?;

if let Some(targets) = targets {
let loss = loss::cross_entropy(
&logits.reshape(((), self.config.vocab_size))?,
&targets.reshape(((),))?,
)?;
Ok((logits, Some(loss)))
} else {
Ok((logits, None))
}
}

pub fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
let embed_layer = embedding(
config.vocab_size,
config.d_model,
vb.pp("model.embed_tokens"),
)?;
let lm_head = linear(config.d_model, config.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(config.d_model, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..config.n_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), config).unwrap())
.collect();
Ok(Self {
embedding: embed_layer,
blocks,
ln_f,
lm_head,
config,
cache: Cache::new(config.context_length, config.d_model / config.n_head, vb)?,
})
}
}

扩展

这里主要是训练的部分,在推理的过程中还涉及到一个比较重要的kv cache。
kv cache主要是缓存自回归过程中的kv,这个kv是不变的,因为这个k 和 v 只和之前的token有关系,所以可以缓存下来,
这样在推理的时候就不用重复计算了。
围绕着prompt产生第一个Token的prefill的阶段和计算完prompt之后的decode阶段也是目前业界比较关注的推理优化的方向。
源代码在这里