DeepSeek V3分析
During the pre-training stage, training DeepSeek-V3 on each trillion tokens requires only 180K
H800 GPU hours, i.e., 3.7 days on our cluster with 2048 H800 GPUs.
DeepSeek实现了非常便宜的训练成本,是一个700B的MoE模型。
基础设施
- 计算集群:在配备 2048 个 NVIDIA H800 GPU 的集群上训练,节点内通过 NVLink 和 NVSwitch 连接,节点间使用 InfiniBand 互连。
- 训练框架:基于 HAI - LLM 框架,采用 16 路管道并行(PP)、64 路专家并行(EP)和 ZeRO - 1 数据并行(DP)。设计 DualPipe 算法减少管道气泡并重叠计算与通信,开发高效的跨节点全对全通信内核,优化内存占用,无需使用昂贵的张量并行(TP)。
- FP8 训练:提出 FP8 混合精度训练框架,对多数计算密集型操作采用 FP8 精度,对部分关键操作保留原始精度,引入细粒度量化策略、提高累积精度、采用 E4M3 格式及在线量化,还降低了内存和通信开销。
- 推理与部署:部署在 H800 集群上,通过分离预填充和解码阶段确保服务水平目标(SLO)和高吞吐量。预填充阶段最小部署单元为 4 节点 32 个 GPU,采用特定并行策略和冗余专家策略确保负载均衡;解码阶段最小部署单元为 40 节点 320 个 GPU,采用相应并行策略和冗余专家策略,并探索动态冗余策略。
- 硬件设计建议:针对通信硬件,期望未来硬件能卸载通信任务,统一网络接口;针对计算硬件,建议提高 FP8 GEMM 累积精度、支持细粒度量化、在线量化和转置 GEMM 操作。
并行度配置
在prefill阶段,attention模块采用4路张量并行+8路数据并行,moe模块采用32路专家并行。这样并行的目的是在满足首token时延的要求下,最大化系统吞吐(和训练任务类似)。
在decode阶段,DeepSeek-V3采取320路专家并行(256个小专家+64个热点专家),有效降低解码时延,并缓解负载不均衡的问题。
DeepSeek-V3 采用了多种并行策略,包括 16 路流水线并行(PP),这一策略有助于提高训练效率,加快模型的处理速度。同时,还应用了 64 路专家并行(EP),且在 8 个节点上进行,能够充分发挥多节点的计算优势。此外,ZeRO-1 数据并行(DP)也被运用到训练中,进一步提升了模型的训练效果。
ZeRO-1 优化器被切分到不同的GPU上。 《大模型动力引擎——PyTorch性能与显存优化手册》有提到这个优化,总结的很好。
假设我们有N=64块GPU进行数据并行训练,在ZeRO-1阶段,优化器的状态量首先被分散存储到所有GPU中,此时单张GPU上的内存使用量骤降到(4+4+8/64)*7.5=60.9GB。ZeRO-2阶段进一步地将模型的梯度也分散存储,此时单张GPU上的内存使用量便是(4+(4+8)/64)7.5=31.4GB。而ZeRO-3阶段将模型的参数也分散存储到N个节点,此时每张GPU的内存消耗只有(4+4+8)/647.5=1.875GB。从单卡需要120GB到仅需不到2GB内存,这个优化效果是不是有点惊艳?不过需要再次强调的是,这样巨大的显存优化是有代价的,显存切分的程度越高,相应的通信开销也会增加。因此,根据实际需求合理地进行显存切分是非常重要的。
MLA
采用类似 LoRA 的架构,借助一个低秩矩阵 “compressed laten vector”,kvcache 仅需对低秩的 key-value 对以及附带旋转位置编码(RoPE)的 key 进行缓存。
MoE
除了针对 Top k、routed experts 运用添加了激活函数的加权求和方式外,还额外引入了 shared experts。在 gate 的激活函数里增添一个 bias,以此来化解 balance 失衡的难题,在训练阶段,通过调节这个 bias 对 balance 状况予以奖惩,这一调节过程被称作 bias update speed。
就一个 batch、一个序列而言,每个 token 倘若倾向于特定的一些 expert,那么未被选中的 expert 实际上仅相当于训练了极小的 batch size,或者极短的序列,正因如此,才有了这样一种策略,用以平衡 expert 的 batch size 以及序列当中的 token 数量,毕竟序列通常都很长。
DeepSeek-V3 着重凭借辅助损失策略达成负载均衡,与此同时,引入互为补充的序列平衡损失,以防单个序列内部出现极度不平衡的现象。
MTP
类似于 speculative decoding,它同样会计算多个 token,不过具体方式存在一定差异。其 embedding 与 output head 是共用的,这一点和 sd 里的 Medusa 有所不同,Medusa 是由多个头来推测不同位置,而 MTP 则是依靠多个相同的头(只是 attention 有别)去推断不同位置。
MTP 的核心目的在于提升主模型的性能表现,在推理阶段能够直接将 MTP 模块舍去,主模型依旧可以独自正常运作。不仅如此,MTP 模块还能够应用于推测解码环节,以此进一步优化生成延迟问题,让整个流程更加高效流畅。
DualPipe
双流水线pipeline的优化。它实现了前向和后向过程中计算与通信阶段的重叠,有效解决了跨节点专家并行带来的通信负载问题。
FP8
能够不依赖硬件能力做FP8精度的训练,这个点是非常厉害的。
首先,为提高模型训练速度,大部分核心计算操作(尤其是 GEMM 运算),均采用 FP8 精度实现。这些 GEMM 运算接收 FP8 格式的张量输入,输出 BF16 或 FP32 格式的结果。如图6所示,线性运算相关的三个 GEMM 操作,包括 Fprop(前向传播)、Dgrad(激活值反向传播)和 Wgrad(权重反向传播),均采用 FP8 执行。这种设计策略理论上将计算速度提升至原有 BF16 方法的两倍。同时,FP8 格式的 Wgrad GEMM 使得激活值能够以 FP8 格式存储用于反向传播,显著降低了内存使用量。