This follows the PyTorch version of llama-from-scratch. The original's RMSNorm averaged over one dimension too many; this version uses the corrected computation.
First, download the TinyShakespeare dataset, a plain-text collection of Shakespeare's works. This document roughly follows the layout of the paper and skips a few obvious steps, such as setting up a virtual environment and installing dependencies.
A preview of what we will ultimately implement:
```rust
generate(&llama, &vocab, 500)?;
```

ZELBETH: Sey solmenter! 'tis tonguerered if berryishdd, and What his stabe, you, and, but all I pilJefals, mode with, Vurint as steolated have loven OlD the queen'd refore Are been, good plmp: Proforne, wift'es swleen, was no bunderes'd a a quain beath! Tybell is my gateer stalk smen'd as be matious dazest brink thou lord Enves were cIUll, afe and whwas seath This a is, an tale hoice his his onety Meall-tearn not murkawn, fase bettizen'd her, To belacquesterer? baxewed wupl usweggs yet tall An
The implementation uses some Rust idioms that differ from Python; we won't dwell on them here. For details of the Rust syntax, see other references.

Iterate: start with small modules, keep things deterministic, then build up step by step
Create all the helper functions you need to evaluate the model quantitatively (data splitting, training, plotting the loss).
Pick the individual components out of the paper, implement them one at a time, and train and evaluate as you go.
Make sure your layers work as expected.
Use .shape() often. assert is your friend.
Work out the result without matrix multiplication first, then use candle functions to make it efficient.
Have a test that confirms your layer is correct. For example, RoPE embeddings have a specific property you can test for; for the Transformer, you can check that attention works by inspecting the attention matrix.
Test your layers across various batch, sequence, and embedding sizes. Even if a layer works for one size, it may not work for others, which will cause problems at inference time. (A small example of this kind of check follows this list.)
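As a concrete illustration of this kind of check, here is a minimal sketch; `my_layer` and the size values are purely illustrative, not part of the original code:

```rust
// Illustrative shape check across several batch/sequence sizes.
// `my_layer` is a hypothetical layer whose forward preserves the last dimension.
for (batch, seq) in [(1usize, 1usize), (4, 16), (32, 64)] {
    let x = Tensor::zeros((batch, seq, modeConfig.d_model), DType::F32, &Device::Cpu)?;
    let y = my_layer.forward(&x)?;
    assert!(y.dims() == [batch, seq, modeConfig.d_model]);
}
```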
About Llama
Llama is a Transformer-based language model. It is autoregressive (also called a causal model): the model appends each generated token to its input and keeps iterating until it exceeds the context length or produces a stop token. Meta AI open-sourced Llama and stated explicitly that their goal was to make the model more efficient at inference time, rather than to optimize training cost.
Next, we load the libraries and start implementing.
```rust
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::ops::softmax;
use candle_nn::{
    embedding, linear, loss, sequential, Activation, AdamW, Embedding, Init, Linear, Module,
    Optimizer, ParamsAdamW, Sequential, VarBuilder,
};
use core::f32;
use rand::Rng;
use std::collections::HashMap;
use std::fs;
use std::time;
```
Setting up the dataset
While Llama was trained on 1.4T tokens, our dataset, TinyShakespeare (the collected works of Shakespeare), is only about 1 million characters.
```rust
use std::collections::HashMap;
use std::fs;

fn main() {
    let lines = fs::read_to_string("./input.txt")
        .expect("Failed to read the file");

    let mut vocab: Vec<char> = lines.chars().collect();
    vocab.sort_unstable();
    vocab.dedup();

    let itos: HashMap<usize, char> = vocab.iter().enumerate().map(|(i, &ch)| (i, ch)).collect();
    let stoi: HashMap<char, usize> = vocab.iter().enumerate().map(|(i, &ch)| (ch, i)).collect();

    println!("{}", &lines[..30.min(lines.len())]);
}
```
First Citizen:
Before we proce
They used the SentencePiece byte-pair-encoding tokenizer, but we will just use a simple character-level tokenizer.
```rust
use std::collections::HashMap;
use std::fs;

struct Vocab {
    itos: HashMap<u32, char>,
    stoi: HashMap<char, u32>,
}

impl Vocab {
    fn new(itos: HashMap<u32, char>, stoi: HashMap<char, u32>) -> Vocab {
        Vocab { itos, stoi }
    }

    fn decode(&self, ids: &[u32]) -> String {
        ids.iter().map(|&id| self.itos[&id]).collect()
    }

    fn encode(&self, text: &str) -> Vec<u32> {
        text.chars().map(|ch| self.stoi[&ch]).collect()
    }

    fn len(&self) -> usize {
        self.itos.len()
    }

    fn build(lines: &str) -> Self {
        let mut vocab: Vec<char> = lines.chars().collect();
        vocab.sort();
        vocab.dedup();
        let itos: HashMap<u32, char> = vocab
            .iter()
            .enumerate()
            .map(|(i, &ch)| (i as u32, ch))
            .collect();
        let stoi: HashMap<char, u32> = vocab
            .iter()
            .enumerate()
            .map(|(i, &ch)| (ch, i as u32))
            .collect();
        Self { itos, stoi }
    }
}

fn main() {
    let lines = fs::read_to_string("./input.txt").expect("Failed to read the file");
    let vocab = Vocab::build(&lines);
    println!("vocab size = {}", vocab.len());
    println!("{}", vocab.decode(&vocab.encode("hello")));
}
```
vocab size = 65
hello
Because the dataset is small, we don't need to worry about holding it in memory.
We create a config object to store the basic model parameters. This makes the code more readable and the configuration easy to change. Rust is statically typed, so every field has an explicit type.
```rust
let mut modeConfig = ModelConfig {
    vocab_size: vocab.len(),
    ..Default::default()
};
```
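Note that the ModelConfig struct itself never appears in the original snippets. A minimal definition consistent with how the fields are used below might look like this; the exact field list and the Default/Copy derives are assumptions:

```rust
// Hypothetical ModelConfig matching how the later snippets use it.
// Deriving Default lets fields be filled in step by step; Copy lets the
// config be passed by value into the various load functions.
#[derive(Debug, Clone, Copy, Default)]
struct ModelConfig {
    vocab_size: usize,
    d_model: usize,
    context_length: usize,
    batch_size: usize,
    hidden_dim: usize,
    n_head: usize,
    n_kv_head: usize,
    n_layers: usize,
    epochs: usize,
    log_interval: usize,
}
```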
```rust
let dataset = Tensor::from_slice(&vocab.encode(&lines), (lines.len(),), &Device::Cpu).unwrap();
println!("{:?}", dataset.shape());
```
[1115394]
Let's create a get_batches method that produces batches of training data and targets. We will use the same method to produce validation and test data, controlled by the split parameter.
```rust
fn get_batches(
    dataset: &Tensor,
    split: &str,
    batch_size: usize,
    context_length: usize,
) -> Result<(Tensor, Tensor)> {
    // 80/10/10 split into train/val/test.
    let len_of_dataset = dataset.shape().dim(0).unwrap() as f32;
    let batch_data = match split {
        "val" => dataset.i((0.8 * len_of_dataset) as usize..(0.9 * len_of_dataset) as usize)?,
        "test" => dataset.i((0.9 * len_of_dataset) as usize..)?,
        _ => dataset.i(..(0.8 * len_of_dataset) as usize)?,
    };

    // Pick batch_size random starting points; y is x shifted right by one token.
    let mut rng = rand::thread_rng();
    let data_len = batch_data.shape().dim(0)?;
    let indices: Vec<usize> = (0..batch_size)
        .map(|_| rng.gen_range(0..data_len - context_length - 1))
        .collect();

    let mut x_batches = Vec::with_capacity(batch_size);
    let mut y_batches = Vec::with_capacity(batch_size);
    for idx in indices {
        let x = batch_data.i(idx..(idx + context_length))?;
        let y = batch_data.i((idx + 1)..(idx + context_length + 1))?;
        x_batches.push(x);
        y_batches.push(y);
    }

    let x_tensor = Tensor::stack(&x_batches, 0)?;
    let y_tensor = Tensor::stack(&y_batches, 0)?;
    Ok((x_tensor, y_tensor))
}

modeConfig.context_length = 16;
modeConfig.batch_size = 8;

let batch = get_batches(
    &dataset,
    "train",
    modeConfig.batch_size,
    modeConfig.context_length,
)?;
println!(
    "batch size {}, context_length {}",
    batch.0.shape().dim(0)?,
    batch.0.shape().dim(1)?
);
for i in 0..modeConfig.batch_size {
    println!(
        "{:?}, {:?}",
        vocab.decode(&batch.0.i(i)?.to_vec1()?),
        vocab.decode(&batch.1.i(i)?.to_vec1()?),
    );
}
```
":\nBut, that I'll", "\nBut, that I'll "
"ng?\nWhy, then th", "g?\nWhy, then the"
"s so blind, but ", " so blind, but s"
"thy offices,\nSo ", "hy offices,\nSo r"
"ords, how plainl", "rds, how plainly"
"IET:\nHere's such", "ET:\nHere's such "
"wer\nTo take off ", "er\nTo take off s"
" hurry from the ", "hurry from the f"
An interesting aspect of implementing a paper is that there are two senses in which the model "works": it compiles (the tensor shapes line up between layers) and it trains (the loss goes down). We also define a way to evaluate the model. We want this before defining the model itself, because we will use it to track performance while training.
```rust
fn evaluate_loss(
    model: &SimpleBrokenModel,
    dataset: &Tensor,
    vocab: &Vocab,
    config: ModelConfig,
) -> Result<HashMap<String, f32>> {
    let mut out = HashMap::new();
    for split in ["train", "val"] {
        let mut losses = Vec::new();
        for _ in 0..10 {
            let (xs, ys) =
                get_batches(&dataset, split, config.batch_size, config.context_length)?;
            let (_, loss) = model.forward(&xs, Some(&ys))?;
            let loss = loss.unwrap();
            losses.push(loss.to_scalar::<f32>()?);
        }
        let avg_loss = losses.iter().sum::<f32>() / losses.len() as f32;
        out.insert(split.to_owned(), avg_loss);
    }
    Ok(out)
}
```
Setting up a simple working model
This is a basic feed-forward neural network with an embedding layer. It's the baseline model we start from; we will then swap out its parts one by one until we end up with the model described in the Llama paper.
```rust
struct SimpleBrokenModel {
    embedding: Embedding,
    mlp: Sequential,
    config: ModelConfig,
}

impl SimpleBrokenModel {
    fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
        let embeds = self.embedding.forward(x)?;
        let logits = self.mlp.forward(&embeds)?;
        if let Some(targets) = targets {
            let loss = loss::cross_entropy(
                &logits.reshape(((), self.config.vocab_size))?,
                &targets.reshape(((),))?,
            )?;
            Ok((logits, Some(loss)))
        } else {
            Ok((logits, None))
        }
    }

    fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
        let embedding = embedding(
            config.vocab_size,
            config.d_model,
            vb.pp("model.embed_tokens"),
        )?;
        let mlp = sequential::seq()
            .add(linear(
                config.d_model,
                config.d_model,
                vb.push_prefix("model.fc1"),
            )?)
            .add(Activation::Relu)
            .add(linear(
                config.d_model,
                config.vocab_size,
                vb.push_prefix("model.fc2"),
            )?);
        Ok(Self {
            embedding,
            mlp,
            config,
        })
    }
}

modeConfig.d_model = 128;
modeConfig.batch_size = 32;

let (xs, ys) = get_batches(
    &dataset,
    "train",
    modeConfig.batch_size,
    modeConfig.context_length,
)?;

let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let model = SimpleBrokenModel::load(vb, modeConfig)?;
let (logits, loss) = model.forward(&xs, Some(&ys))?;
println!("{:?} {:?}", logits, loss);

// Count the trainable parameters.
let mut params_count: usize = 0;
for (name, var) in varmap.data().lock().unwrap().iter() {
    println!("{}: {:?}", name, var.elem_count());
    params_count += var.elem_count();
}
println!("params count: {}", params_count);
```
Tensor[dims 32, 16, 65; f32] Some(Tensor[5.266067; f32])
model.fc2.weight: 8320
model.embed_tokens.weight: 8320
model.fc1.bias: 128
model.fc2.bias: 65
model.fc1.weight: 16384
params count: 33217
At this point we have to start paying attention to tensor shapes and make the matrix dimensions match. Look at this line in our model definition:
```rust
let loss = loss::cross_entropy(
    &logits.reshape(((), self.config.vocab_size))?,
    &targets.reshape(((),))?,
)?;
```
We have to reshape the logits and targets tensors so that their dimensions match when they are compared. We do this with the reshape method. The () argument means "infer this dimension from the others". So here we are saying: reshape logits and targets so they have the same number of rows, using however many columns are needed to make that work. This is a common pattern when working with batched data.
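A minimal illustration of the shapes involved, using the logits and ys from the snippet above (the numbers match the batch size and context length set earlier; this is illustrative only):

```rust
// logits: (batch, context, vocab) = (32, 16, 65) -> (512, 65)
// targets: (batch, context)       = (32, 16)     -> (512,)
let flat_logits = logits.reshape(((), 65usize))?; // number of rows is inferred
let flat_targets = ys.reshape(((),))?;            // length is inferred
assert!(flat_logits.dims() == [32 * 16, 65]);
assert!(flat_targets.dims() == [32 * 16]);
```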
Let's train our SimpleBrokenModel to make sure gradients flow. After confirming that, we can swap parts of it out to match Llama, train again, and track our progress. At this point I started keeping a log of my training runs, so that I can easily go back to an earlier run if I break something.
```rust
fn train(
    config: ModelConfig,
    model: &SimpleBrokenModel,
    opt: &mut AdamW,
    dataset: &Tensor,
    vocab: &Vocab,
) -> Result<()> {
    let mut start_time = std::time::Instant::now();
    for epoch in 0..config.epochs {
        let (xs, ys) =
            get_batches(&dataset, "train", config.batch_size, config.context_length)?;
        let (_, loss) = model.forward(&xs, Some(&ys))?;
        opt.backward_step(&loss.unwrap())?;

        if epoch % config.log_interval == 0 {
            let batch_duration = start_time.elapsed().as_secs_f32();
            let loss = evaluate_loss(&model, dataset, vocab, config)?;
            let val_loss = loss.get("val").unwrap();
            let eta = batch_duration * (config.epochs - epoch) as f32;
            let eta = eta.round();
            println!(
                "Epoch: {epoch} | Loss: {val_loss} | Time: {batch_duration} | ETA in seconds {eta}"
            );
            start_time = time::Instant::now();
        }
    }
    Ok(())
}

modeConfig.log_interval = 10;
modeConfig.epochs = 100;

let mut opt = candle_nn::AdamW::new(varmap.all_vars(), ParamsAdamW::default())?;
train(modeConfig, &model, &mut opt, &dataset, &vocab)?;
let out = evaluate_loss(&model, &dataset, &vocab, modeConfig);
println!("{:?}", out);
```
Epoch: 10 | Loss: 3.9159875 | Time: 6.5813394 | ETA in seconds 599
Epoch: 20 | Loss: 3.26492 | Time: 6.3639965 | ETA in seconds 515
Epoch: 30 | Loss: 2.9944448 | Time: 6.3596206 | ETA in seconds 452
Epoch: 40 | Loss: 2.8793342 | Time: 6.357106 | ETA in seconds 388
Epoch: 50 | Loss: 2.7827232 | Time: 6.3562865 | ETA in seconds 324
Epoch: 60 | Loss: 2.764416 | Time: 6.352279 | ETA in seconds 260
Epoch: 70 | Loss: 2.7196321 | Time: 6.356127 | ETA in seconds 197
Epoch: 80 | Loss: 2.7631993 | Time: 6.357493 | ETA in seconds 134
Epoch: 90 | Loss: 2.696882 | Time: 6.358631 | ETA in seconds 70
Epoch: 100 | Loss: 2.670012 | Time: 6.3603354 | ETA in seconds 6
Ok({"train": 2.591057, "val": 2.6625311})
```rust
fn generate(model: &SimpleBrokenModel, vocab: &Vocab, max_tokens: usize) -> Result<()> {
    // Start with 5 sequences, each seeded with token 0 (a newline).
    let mut token_ids = Tensor::zeros((5, 1), DType::U32, &Device::Cpu).unwrap();
    for _ in 0..max_tokens {
        let (logits, _) = model.forward(&token_ids, None)?;
        assert!(logits.shape().dims() == [token_ids.dim(0)?, token_ids.dim(1)?, 65]);

        // Only the logits of the last position are used to pick the next token.
        let last_step_logits = logits.i((.., logits.dim(1)? - 1))?;
        assert!(last_step_logits.shape().dims() == [token_ids.dim(0)?, 65]);
        let probs = softmax(&last_step_logits, last_step_logits.dims().len() - 1)?;
        assert!(probs.shape().dims() == [token_ids.dim(0)?, 65]);
        let next_token = probs.argmax(probs.dims().len() - 1)?;
        assert!(next_token.shape().dims() == [token_ids.dim(0)?]);

        // Append the new token and feed the whole sequence back in.
        token_ids = Tensor::cat(&[token_ids, next_token.reshape(((), 1))?], 1)?;
    }
    for v in &token_ids.to_vec2()? {
        let text = vocab.decode(v);
        println!("{}", text);
    }
    Ok(())
}

generate(&model, &vocab, 10)?;
```
['\nFind!\nD:\nAr t,\nLis sthte o t l',
'\nAnd ronnot ar\nBE:\nKINRDYOrspr;',
'\nI t athe momyengthend thanswal',
'\nFis t bp he\nLacarn.\nA:\nYOMI wi',
'\nWh ly sck\nB-de pll t\nHERIns ou']
Not terrible, but not great either. Still, we now have a working model that trains down to a reasonable validation loss, so we can iterate on it and move it closer to Llama.
Llama specifics
Llama makes three architectural modifications to the original Transformer:
RMSNorm for pre-normalization
Rotary positional embeddings (RoPE)
The SwiGLU activation function
We will add each modification to our base model one at a time and iterate.
RMSNorm
In Vaswani 2017, the original Transformer uses LayerNorm. In Llama, the authors use RMSNorm, which scales the vector by its root mean square without centering it. In addition, while Vaswani applies normalization to the output of the attention layer (post-norm), Llama applies it to the input (pre-norm).
This article has a good explanation of RMSNorm.
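Written out, the operation implemented below is the standard RMSNorm with a learned per-dimension gain $g$ and a small $\epsilon$ for numerical stability:

$$ \text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g $$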
```rust
pub struct RmsNorm {
    scale: Tensor,
    eps: f64,
}

impl RmsNorm {
    fn new(size: usize, vb: VarBuilder) -> Result<Self> {
        Ok(RmsNorm {
            scale: vb.get_with_hints(size, "weight", Init::Const(1.))?,
            eps: 1e-6,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Mean of the squared values over the last (feature) dimension only.
        let x_sqr = x.sqr()?;
        assert!(x_sqr.shape().dims() == x.shape().dims());
        let norm_x = (x_sqr.mean(D::Minus1)? + self.eps)?.sqrt()?;
        assert!(norm_x.shape().dims() == [x.shape().dim(0)?, x.shape().dim(1)?]);

        // Divide by the RMS, then apply the learned per-dimension scale.
        let x_normed = x.broadcast_div(&norm_x.reshape((
            norm_x.shape().dim(0)?,
            norm_x.shape().dim(1)?,
            (),
        ))?)?;
        assert!(x_normed.shape().dims() == x.shape().dims());
        let x = (x_normed.broadcast_mul(&self.scale))?;
        Ok(x)
    }
}

let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let rms_norms = RmsNorm::new(2, vb)?;
let batch = Tensor::new(
    vec![
        vec![vec![1f32, 1f32], vec![1.2f32, 2f32], vec![3f32, 3f32]],
        vec![vec![4f32, 43f32], vec![5f32, 5f32], vec![61f32, 6f32]],
    ],
    &Device::Cpu,
)?;
let out = rms_norms.forward(&batch)?;
```
Tensor[dims 2, 3, 2; f32]
```rust
struct SimpleBrokenModel {
    embedding: Embedding,
    mlp: Sequential,
    rms_norm: RmsNorm,
    config: ModelConfig,
}

impl SimpleBrokenModel {
    fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
        let embeds = self.embedding.forward(x)?;
        let normed_embeds = self.rms_norm.forward(&embeds)?;
        let logits = self.mlp.forward(&normed_embeds)?;
        if let Some(targets) = targets {
            let loss = loss::cross_entropy(
                &logits.reshape(((), self.config.vocab_size))?,
                &targets.reshape(((),))?,
            )?;
            Ok((logits, Some(loss)))
        } else {
            Ok((logits, None))
        }
    }

    fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
        let embedding = embedding(
            config.vocab_size,
            config.d_model,
            vb.pp("model.embed_tokens"),
        )?;
        let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
        let mlp = sequential::seq()
            .add(linear(
                config.d_model,
                config.d_model,
                vb.push_prefix("model.fc1"),
            )?)
            .add(Activation::Relu)
            .add(linear(
                config.d_model,
                config.vocab_size,
                vb.push_prefix("model.fc2"),
            )?);
        Ok(Self {
            embedding,
            mlp,
            config,
            rms_norm,
        })
    }
}
```
Epoch: 10 | Loss: 4.1559505 | Time: 6.779387 | ETA in seconds 617
Epoch: 20 | Loss: 4.14648 | Time: 6.7727704 | ETA in seconds 549
Epoch: 30 | Loss: 4.1364665 | Time: 6.776428 | ETA in seconds 481
Epoch: 40 | Loss: 4.125594 | Time: 6.772582 | ETA in seconds 413
Epoch: 50 | Loss: 4.120083 | Time: 6.7661977 | ETA in seconds 345
Epoch: 60 | Loss: 4.1099877 | Time: 6.760399 | ETA in seconds 277
Epoch: 70 | Loss: 4.0996284 | Time: 6.7623253 | ETA in seconds 210
Epoch: 80 | Loss: 4.0902996 | Time: 6.761824 | ETA in seconds 142
Epoch: 90 | Loss: 4.0833025 | Time: 6.76845 | ETA in seconds 74
Epoch: 100 | Loss: 4.070025 | Time: 6.7624454 | ETA in seconds 7
Ok({"train": 4.072861, "val": 4.0711236})
Judging from these results, normalization by itself did not improve the model's performance; the loss curve simply became smoother. So we keep iterating.
Rotary Embeddings
RoPE is a positional-encoding method for Transformers. In "Attention is All You Need", the authors propose two kinds of positional encoding, learned and fixed. RoPE instead encodes a token's position in the sequence by rotating pairs of embedding dimensions, using a different rotation angle at each position.
The cos and sin values can be precomputed and cached to avoid recomputation; later we store them in a single cache struct.
RoPE applies the positional encoding by multiplying each pair of elements of the hidden_state by a rotation matrix:
```text
[x0, x1, ..., xn]
y0 = x0 * cos(theta) - x1 * sin(theta)
y1 = x0 * sin(theta) + x1 * cos(theta)
[y0, y1, ..., yn]
```
Here n is half of d_model, and theta is a fixed value determined by the position: theta = m / 10000^(2i/n), where m is the token's position in the sequence and i is the index of the pair within d_model.
The meaning of this formula is that each (x0, x1) pair of the feature vector is rotated by a fixed angle; the rotation is not learned but precomputed, and it can be used to express relative position information.
For example, the (x0, x1) pair at pos_index = 0 and the same pair at pos_index = 1 are separated by one constant rotation step.
freq_cis caches the precomputed cos and sin values so they do not have to be recomputed.
```rust
struct Cache {
    cos: Tensor,
    sin: Tensor,
}

impl Cache {
    fn new(context_length: usize, n_elem: usize, vb: VarBuilder) -> Result<Cache> {
        // theta_i = 1 / 10000^(2i / n_elem) for each pair of dimensions.
        let theta: Vec<_> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), vb.device())?;

        // Outer product of positions (0..context_length) and theta.
        let idx_theta = Tensor::arange(0, context_length as u32, vb.device())?
            .to_dtype(DType::F32)?
            .reshape((context_length, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        let freq_cis_real = idx_theta.cos()?;
        let freq_cis_imag = idx_theta.sin()?;
        let cos = freq_cis_real.reshape((context_length, n_elem / 2, 1))?;
        let sin = freq_cis_imag.reshape((context_length, n_elem / 2, 1))?;
        Ok(Cache { cos, sin })
    }
}
```
When applying RoPE:
```rust
fn apply_rotary_emb(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
    let (b_sz, seq_len, n_embd) = x.dims3()?;
    let cos = cache.cos.i(..seq_len)?;
    let sin = cache.sin.i(..seq_len)?;
    let cos = cos.broadcast_as((b_sz, seq_len, n_embd / 2, 1))?;
    let sin = sin.broadcast_as((b_sz, seq_len, n_embd / 2, 1))?;

    // View the embedding as (n_embd / 2) pairs and rotate each pair.
    let x = x.reshape((b_sz, seq_len, n_embd / 2, 2))?;
    let x0 = x.narrow(D::Minus1, 0, 1)?;
    let x1 = x.narrow(D::Minus1, 1, 1)?;
    let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
    let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
    let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, n_embd))?;
    Ok(rope)
}
```
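As mentioned in the testing tips at the top, RoPE has a property you can test for: the dot product between a query rotated at position m and a key rotated at position n depends only on the offset n - m. A minimal standalone sketch of such a check, using a scalar helper rather than the tensor code above (names and values are illustrative):

```rust
// Rotate a 2-d pair by angle `pos * theta` (scalar version, for testing only).
fn rotate(x: (f32, f32), pos: f32, theta: f32) -> (f32, f32) {
    let (c, s) = ((pos * theta).cos(), (pos * theta).sin());
    (x.0 * c - x.1 * s, x.0 * s + x.1 * c)
}

fn main() {
    let (q, k) = ((0.3f32, -1.2f32), (0.7f32, 0.5f32));
    let theta = 1.0 / 10000f32.powf(0.0);
    let dot = |a: (f32, f32), b: (f32, f32)| a.0 * b.0 + a.1 * b.1;
    // Position pairs (2, 5) and (7, 10) share the same offset of 3,
    // so the dot products of the rotated vectors must be equal.
    let d1 = dot(rotate(q, 2.0, theta), rotate(k, 5.0, theta));
    let d2 = dot(rotate(q, 7.0, theta), rotate(k, 10.0, theta));
    assert!((d1 - d2).abs() < 1e-5);
}
```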
Self Attention

```rust
struct AttentionModel {
    embedding: Embedding,
    mlp: Sequential,
    rms_norm: RmsNorm,
    config: ModelConfig,
    self_attention: SelfAttention,
    cache: Cache,
}

impl AttentionModel {
    fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
        let embeds = self.embedding.forward(x)?;
        let normed_embeds = self.rms_norm.forward(&embeds)?;
        let y = self.self_attention.forward(&normed_embeds, &self.cache)?;
        let logits = self.mlp.forward(&y)?;
        if let Some(targets) = targets {
            let loss = loss::cross_entropy(
                &logits.reshape(((), self.config.vocab_size))?,
                &targets.reshape(((),))?,
            )?;
            Ok((logits, Some(loss)))
        } else {
            Ok((logits, None))
        }
    }

    fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
        let embedding = embedding(
            config.vocab_size,
            config.d_model,
            vb.pp("model.embed_tokens"),
        )?;
        let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
        let self_attention = SelfAttention::load(vb.pp("model.self_attention"), config)?;
        let mlp = sequential::seq()
            .add(linear(
                config.d_model,
                config.d_model,
                vb.push_prefix("model.fc1"),
            )?)
            .add(Activation::Relu)
            .add(linear(
                config.d_model,
                config.vocab_size,
                vb.push_prefix("model.fc2"),
            )?);
        Ok(Self {
            embedding,
            mlp,
            config,
            rms_norm,
            self_attention,
            cache: Cache::new(config.context_length, config.d_model, vb)?,
        })
    }
}
```
Tip: understand how tensor dimensions differ between training and inference.
While at training time you can expect tensor dimensions to match the model parameters closely, e.g. batch.shape = (config['batch_size'], config['context_window'], config['d_model']), at inference time you may have to deal with a single example, e.g. batch.shape = (1, 1, config['d_model']). So when you index inside the forward pass, make sure you use shapes derived from the input, not necessarily from the model parameters.
MultiHeadRopeAttention
Let's wrap this single attention head in a multi-head attention layer and see what happens during training.
The attention implemented here is grouped-query attention (GQA): with n_kv_head = 1 it is MQA, with 1 < n_kv_head < n_head it is GQA, and with n_kv_head = n_head it is the original MHA (see the configuration sketch just below).
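The original snippets never show the head configuration being set; a plausible example, with illustrative values that are not from the original code:

```rust
// Illustrative head configuration (values are assumptions, not from the original).
modeConfig.n_head = 8;
modeConfig.n_kv_head = 4; // 1 => MQA, n_head => MHA, anything in between => GQA
```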
```rust
struct MultiHeadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
}

impl MultiHeadAttention {
    fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
        let q_proj = linear(config.d_model, config.d_model, vb.pp("model.q_proj"))?;
        // K and V only need n_kv_head heads; they are repeated up to n_head later.
        let k_proj = linear(
            config.d_model,
            (config.d_model / config.n_head) * config.n_kv_head,
            vb.pp("model.k_proj"),
        )?;
        let v_proj = linear(
            config.d_model,
            (config.d_model / config.n_head) * config.n_kv_head,
            vb.pp("model.v_proj"),
        )?;
        let o_proj = linear(config.d_model, config.d_model, vb.pp("model.o_proj"))?;
        println!(
            "MHA config n_head {} n_kv_head {}",
            config.n_head, config.n_kv_head
        );
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            n_head: config.n_head,
            n_kv_head: config.n_kv_head,
            head_dim: config.d_model / config.n_head,
        })
    }

    fn apply_rotary_emb(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
        let (b_sz, seq_len, h, n_embd) = x.dims4()?;
        let cos = cache.cos.i(..seq_len)?;
        let sin = cache.sin.i(..seq_len)?;
        let cos = cos.unsqueeze(1)?;
        let sin = sin.unsqueeze(1)?;
        let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
        let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
        let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
        let x0 = x.narrow(D::Minus1, 0, 1)?;
        let x1 = x.narrow(D::Minus1, 1, 1)?;
        let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
        let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
        let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
        Ok(rope)
    }

    fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;
        assert!(n_embd == self.n_head * self.head_dim);

        // Split into heads, apply RoPE to q and k, then repeat k/v up to n_head.
        let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
        let k = k.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
        let v = v.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
        let q = self.apply_rotary_emb(&q, cache)?;
        let k = self.apply_rotary_emb(&k, cache)?;
        let k = self.repeat_kv(k)?;
        let v = self.repeat_kv(v)?;

        let q = q.transpose(1, 2)?.contiguous()?;
        let k = k.transpose(1, 2)?.contiguous()?;
        let v = v.transpose(1, 2)?.contiguous()?;

        // Scaled dot-product attention.
        let attn = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let attn = softmax(&attn, D::Minus1)?;
        let y = attn.matmul(&v)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.o_proj.forward(&y)?;
        Ok(y)
    }

    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.n_head / self.n_kv_head;
        if n_rep == 1 {
            Ok(x)
        } else {
            let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
            let x = x
                .unsqueeze(3)?
                .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
                .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
            Ok(x)
        }
    }
}
```
The complete model
```rust
struct AttentionModel {
    embedding: Embedding,
    mlp: Sequential,
    rms_norm: RmsNorm,
    config: ModelConfig,
    self_attention: MultiHeadAttention,
    cache: Cache,
}

impl AttentionModel {
    fn forward(&self, x: &Tensor, targets: Option<&Tensor>) -> Result<(Tensor, Option<Tensor>)> {
        let embeds = self.embedding.forward(x)?;
        let normed_embeds = self.rms_norm.forward(&embeds)?;
        let y = self.self_attention.forward(&normed_embeds, &self.cache)?;
        let logits = self.mlp.forward(&y)?;
        if let Some(targets) = targets {
            let loss = loss::cross_entropy(
                &logits.reshape(((), self.config.vocab_size))?,
                &targets.reshape(((),))?,
            )?;
            Ok((logits, Some(loss)))
        } else {
            Ok((logits, None))
        }
    }

    fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
        let embedding = embedding(
            config.vocab_size,
            config.d_model,
            vb.pp("model.embed_tokens"),
        )?;
        let rms_norm = RmsNorm::new(config.d_model, vb.pp("model.rms_norm"))?;
        let self_attention =
            MultiHeadAttention::load(vb.pp("model.multi_head_attention"), config)?;
        let mlp = sequential::seq()
            .add(linear(
                config.d_model,
                config.d_model,
                vb.push_prefix("model.fc1"),
            )?)
            .add(Activation::Relu)
            .add(linear(
                config.d_model,
                config.vocab_size,
                vb.push_prefix("model.fc2"),
            )?);
        Ok(Self {
            embedding,
            mlp,
            config,
            rms_norm,
            self_attention,
            // RoPE is applied per head, so the cache uses head_dim = d_model / n_head.
            cache: Cache::new(config.context_length, config.d_model / config.n_head, vb)?,
        })
    }
}
```
```rust
generate(&model, &vocab, 10)?;
```
['\n\n\n\n\n\n\n\nI\n\nOOOOOOOOOFOOtOOOOOOO',
'\nIIIIII IIIIIIIIIIIIIIIIIIIIIII',
'\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
'\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\naaame',
'\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n']
So this looks bad. What's going on here? Let's start debugging by looking at the attention.
Right now the attention is unmasked: every position attends to every other position. Why is that bad? We are trying to predict the next token based only on the previous tokens, but here the model also attends to later tokens. In other words, the model is cheating, leaking information from the future. That's the problem, and it's why we need a causal mask.
```rust
pub struct Cache {
    cos: Tensor,
    sin: Tensor,
    // Upper-triangular mask: 1 where position j lies in the future of position i.
    mask: Tensor,
}

impl Cache {
    fn new(context_length: usize, n_elem: usize, vb: VarBuilder) -> Result<Cache> {
        // ... cos and sin are computed exactly as before ...
        let mask: Vec<_> = (0..context_length)
            .flat_map(|i| (0..context_length).map(move |j| u8::from(j > i)))
            .collect();
        let mask = Tensor::from_slice(&mask, (context_length, context_length), vb.device())?;
        Ok(Cache { cos, sin, mask })
    }
}

struct MultiHeadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
}

impl MultiHeadAttention {
    fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
        let q_proj = linear(config.d_model, config.d_model, vb.pp("model.q_proj"))?;
        let k_proj = linear(
            config.d_model,
            (config.d_model / config.n_head) * config.n_kv_head,
            vb.pp("model.k_proj"),
        )?;
        let v_proj = linear(
            config.d_model,
            (config.d_model / config.n_head) * config.n_kv_head,
            vb.pp("model.v_proj"),
        )?;
        let o_proj = linear(config.d_model, config.d_model, vb.pp("model.o_proj"))?;
        println!(
            "MHA config n_head {} n_kv_head {}",
            config.n_head, config.n_kv_head
        );
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            n_head: config.n_head,
            n_kv_head: config.n_kv_head,
            head_dim: config.d_model / config.n_head,
        })
    }

    fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;
        assert!(n_embd == self.n_head * self.head_dim);
        let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
        let k = k.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
        let v = v.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?;
        let q = self.apply_rotary_emb(&q, cache)?;
        let k = self.apply_rotary_emb(&k, cache)?;
        let k = self.repeat_kv(k)?;
        let v = self.repeat_kv(v)?;
        let q = q.transpose(1, 2)?.contiguous()?;
        let k = k.transpose(1, 2)?.contiguous()?;
        let v = v.transpose(1, 2)?.contiguous()?;
        let attn = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;

        // Causal mask: future positions are set to -inf before the softmax.
        let mask = cache
            .mask
            .i((..seq_len, ..seq_len))?
            .unsqueeze(0)?
            .unsqueeze(0)?
            .broadcast_as(attn.shape())?;
        let on_true =
            Tensor::new(f32::NEG_INFINITY, attn.device())?.broadcast_as(mask.shape().dims())?;
        let attn = mask.where_cond(&on_true, &attn)?;
        let attn = softmax(&attn, D::Minus1)?;

        let y = attn.matmul(&v)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.o_proj.forward(&y)?;
        Ok(y)
    }

    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.n_head / self.n_kv_head;
        if n_rep == 1 {
            Ok(x)
        } else {
            let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
            let x = x
                .unsqueeze(3)?
                .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
                .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
            Ok(x)
        }
    }
}
```
Now the upper-triangular part of the attention activations (the part corresponding to the future) is effectively zeroed out. Let's see what happens during training.
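If you want to verify this directly (as suggested in the testing tips at the top), here is a minimal standalone sketch of the masking step for a small seq_len of 4; the random scores and tolerance are illustrative only:

```rust
use candle_core::{Device, Result, Tensor, D};
use candle_nn::ops::softmax;

fn main() -> Result<()> {
    let device = Device::Cpu;
    let seq_len = 4usize;
    // Random attention scores for a single head.
    let scores = Tensor::randn(0f32, 1f32, (seq_len, seq_len), &device)?;
    // mask[i][j] = 1 when j > i, i.e. j is a future position.
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (seq_len, seq_len), &device)?;
    let neg_inf = Tensor::new(f32::NEG_INFINITY, &device)?.broadcast_as((seq_len, seq_len))?;
    let attn = softmax(&mask.where_cond(&neg_inf, &scores)?, D::Minus1)?;
    // Every attention weight above the diagonal should be (numerically) zero.
    for (i, row) in attn.to_vec2::<f32>()?.iter().enumerate() {
        for &w in &row[i + 1..] {
            assert!(w.abs() < 1e-6);
        }
    }
    Ok(())
}
```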
SwiGLU
As the paper puts it: "We replace the ReLU non-linearity with the SwiGLU activation function … we use a dimension of $\frac{2}{3} 4d$ instead of the $4d$ used in PaLM." SwiGLU is defined as:
$$ \text{SwiGLU}(x) = \text{Swish}_\beta (xW + b) \otimes (xV + c) $$
where $\otimes$ is the element-wise product. The Swish function is defined as:
$$ \text{Swish}_\beta(x) = x \sigma(\beta x) $$
where $\beta$ is a learnable parameter. In the implementation below we use SiLU, i.e. Swish with $\beta = 1$, which is what Llama uses.
```rust
fn silu(xs: &Tensor) -> Result<Tensor> {
    // SiLU(x) = x * sigmoid(x) = x / (1 + e^(-x))
    xs / (xs.neg()?.exp()? + 1.0)?
}

struct SwiGLU {
    c_fc1: Linear,
    c_fc2: Linear,
    c_proj: Linear,
}

impl SwiGLU {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // silu(gate_proj(x)) * up_proj(x), then project back down to d_model.
        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&x)
    }

    fn load(vb: VarBuilder, cfg: ModelConfig) -> Result<Self> {
        let h_size = cfg.d_model;
        let i_size = cfg.hidden_dim;
        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
        })
    }
}
```
A Llama block uses two residual connections: one around the attention (with RMSNorm applied before it) and one around the SwiGLU feed-forward (again with RMSNorm applied before it).
```rust
struct Block {
    rms_1: RmsNorm,
    attn: MultiHeadAttention,
    rms_2: RmsNorm,
    swiglu: SwiGLU,
}

impl Block {
    fn forward(&self, x: &Tensor, cache: &Cache) -> Result<Tensor> {
        // x = x + attn(rms_norm(x))
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, cache)? + residual)?;
        // x = x + swiglu(rms_norm(x))
        let residual = &x;
        let x = (self.swiglu.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(vb: VarBuilder, cfg: ModelConfig) -> Result<Self> {
        let attn = MultiHeadAttention::load(vb.pp("self_attn"), cfg)?;
        let swiglu = SwiGLU::load(vb.pp("mlp"), cfg)?;
        let rms_1 = RmsNorm::new(cfg.d_model, vb.pp("input_layernorm"))?;
        let rms_2 = RmsNorm::new(cfg.d_model, vb.pp("post_attention_layernorm"))?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            swiglu,
        })
    }
}
```
Now let's stack multiple blocks to build the final, complete model.
```rust
struct Llama {
    embedding: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Linear,
    cache: Cache,
    config: ModelConfig,
}

impl Llama {
    pub fn forward(
        &self,
        x: &Tensor,
        targets: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let (_b_sz, _seq_len) = x.dims2()?;
        let mut x = self.embedding.forward(x)?;
        for block in &self.blocks {
            x = block.forward(&x, &self.cache)?;
        }
        // Final RMSNorm, then project to vocabulary logits.
        let x = self.ln_f.forward(&x)?;
        let logits = self.lm_head.forward(&x)?;
        if let Some(targets) = targets {
            let loss = loss::cross_entropy(
                &logits.reshape(((), self.config.vocab_size))?,
                &targets.reshape(((),))?,
            )?;
            Ok((logits, Some(loss)))
        } else {
            Ok((logits, None))
        }
    }

    pub fn load(vb: VarBuilder, config: ModelConfig) -> Result<Self> {
        let embed_layer = embedding(
            config.vocab_size,
            config.d_model,
            vb.pp("model.embed_tokens"),
        )?;
        let lm_head = linear(config.d_model, config.vocab_size, vb.pp("lm_head"))?;
        let ln_f = RmsNorm::new(config.d_model, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..config.n_layers)
            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), config).unwrap())
            .collect();
        Ok(Self {
            embedding: embed_layer,
            blocks,
            ln_f,
            lm_head,
            config,
            cache: Cache::new(config.context_length, config.d_model / config.n_head, vb)?,
        })
    }
}
```
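The snippets never show the final model being instantiated; under the same assumptions as before (layer count and hidden dimension are illustrative values, not from the original), usage would look something like:

```rust
// Illustrative only: n_layers and hidden_dim are assumptions.
modeConfig.n_layers = 4;
modeConfig.hidden_dim = modeConfig.d_model * 8 / 3; // roughly (2/3) * 4d, as in the paper

let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let llama = Llama::load(vb, modeConfig)?;
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), ParamsAdamW::default())?;
// The train/evaluate_loss helpers above would need their model parameter changed
// from &SimpleBrokenModel to &Llama, but are otherwise reused as-is.
```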
Extensions
This post has mostly covered training. At inference time there is one more important piece: the KV cache. The KV cache stores the keys and values produced during autoregressive decoding. Since the k and v for a given token depend only on that token and the ones before it, they never change and can be cached, so they don't have to be recomputed at every step. The prefill phase (processing the prompt to produce the first token) and the decode phase (generating tokens after the prompt) are also where much of the industry's current work on inference optimization is focused. The source code is here.
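To make the idea concrete, here is a minimal sketch of what a per-layer KV cache could look like in this codebase. This is an assumption, not part of the original code: the attention layer would append the new k/v for the current token to the cached tensors and attend over the full history.

```rust
// Hypothetical per-layer KV cache (not part of the original code).
// During decoding, each step only computes k/v for the new token and
// concatenates them with what was cached at previous steps.
struct KvCache {
    k: Option<Tensor>, // (batch, n_head, cached_len, head_dim)
    v: Option<Tensor>,
}

impl KvCache {
    fn append(&mut self, k_new: &Tensor, v_new: &Tensor) -> Result<(Tensor, Tensor)> {
        let k = match &self.k {
            Some(prev) => Tensor::cat(&[prev, k_new], 2)?, // concat along the sequence dim
            None => k_new.clone(),
        };
        let v = match &self.v {
            Some(prev) => Tensor::cat(&[prev, v_new], 2)?,
            None => v_new.clone(),
        };
        self.k = Some(k.clone());
        self.v = Some(v.clone());
        Ok((k, v))
    }
}
```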