Transformer 计算相关
也是惭愧,早就该补上这一课的,做CS336的大作业才把这个给搞明白。
本文假设读者已经懂了什么是Transformer基础架构,以及对应的基础知识
参数量计算
我们这里就用CS 336里面的架构为例

这里我们可以看到LLM看着大,其实很大一部分是Transformer Block的重复,一个Transformer Block 里面又有3个主要的模块,Multi Head Attention (MHA), Feed Forward Network (FFN) 以及RMSNorm,接下来我们先从Transformer Block开始来计算参数量,然后再考虑其他部分。这里简单提及一下常见的缩写
$B$ = Batch size, $S$ = Seqence Length, $H = d$ = Hidden Layer Size, $V$ = Vocabulary size $d_{ff} = \frac{8}{3}H$
Transformer Block
RMSNorm
首先我们要知道RMSNorm的公式, 这里$X \in R^{1\times d} = {x_1, x_2, …, x_d}$, $g \in R^{1\times d}$
$$
RMSNorm(x_i) = \frac{x_i}{RMS(X)}g_i\
RMS(X) = \sqrt{\frac{\sum_{i=1}^{d}{x_i^2}}{d} + \epsilon}
$$
所以一层RMSNorm就有$H$个参数
MHA
对于MHA,我们一般是对于输入X,我们有$W_q, W_k, W_v, W_o \in R^{d \times d}$, 也就是说,直接是$4H^2$, 就这么简单
FFN
现在都是用的Swiglu FFN了,但是为了保持和传统的FFN一样,swiglu FFN的中间层是$d_{ff}= \frac{8}{3}d$,而传统的FFN中间层是$d_{ff} = 4d$, 所以哪怕用swiglu FFN,或者是传统的FFN,其实参数量都是一样的,传统的计算方法是
$$
2 \times H \times 4H = 8H^2
$$
swiglu FFN的公式是
$$
SwiGLUFFN(x) = W_3(SiLU(W_2X)\otimes (W_1X))
$$
这里$W_1 \in R^{d \times d_{ff}}$, $W_2 \in R^{d \times d_{ff}}$, $W_3 \in R^{d_{ff} \times d}$,所以总参数量就是
$$
3 \times d \times d_{ff} = 3 \times d \times \frac{8}{3}d = 8d^2 = 8H^2
$$
所以对于一个Transformer Block 来说,一层里面包含2个RMSNorm + 1个MHA + 1个FFN,所以加起来就是
$$
2\times H + 1\times 4H^2 + 1\times 8H^2 = 12H^2 + 2H
$$
我们假设有$L$层Transformer Block, 所以Transformer Blocks的参数量就是
$$
L(12H^2 + 2H) = 2L(6H^2 + H)
$$
Embedding + Readout Layer
这俩是互为相反形状的一个大矩阵,$W \in R^{V \times d}$, 所以这两层的参数量就是 $2VH$
再加上最后一层单独的RMSNorm的H个参数,我们现在知道了纯参数有
$$
2VH + 2L(6H^2 + H) + H
$$
所以就从储存来说,纯参数量大小只和 $V, H, L$ 有关,我们来个实例,GPT3万恶之源了属于是, 也就是宣称有175B的参数,具体是以下构成的
| Item | Number |
|---|---|
| L | 96 |
| H | 12288 |
| V | 50257 |
| S | 2048 |
| A | 96 |
所以我们计算一下总的参数
1 | L = 96 |
175B,很合理,那么存储这些数据需要多少硬盘空间呢,如果我们用bf16来存,那么1个参数就要吃掉16bit也就是2Byte
1 | P * 2 / 1024**3 |
326.3 GB, 还不错,一张SSD就能够带走。但是这个仅仅是推理,如果要训练呢?
AdamW 参数量计算
首先对于每一个参数,我们求导都会占用相同的分量
然后我们看看AdamW的公式

舒服了,$m$和$v$都是参数量大小的,记录了每一个参数的一阶和二阶动量,所以优化器这里的额外参数量是$3P$, 这还没完,因为如果我们要求梯度,还得保存正向的值
Activation 参数量计算 (这部分可以牺牲速度来换显存)
如果显存实在不够,这个可以被省略掉,但是我一般都不开gradient checkpoint,太慢,由于反向传播需要每一层的输入值,这个我这样解释可能比较清楚
$$
Y = W_2\sigma(XW_1)
$$
典型的两层神经网络,所以说如果我们对$W_2$ 求个导,就可以得到
$$
\frac{\partial Loss}{\partial W_2} = \frac{\partial Loss}{\partial Y}\frac{\partial Y}{\partial W_2} = [\sigma(XW_1)]^T \frac{\partial Loss}{\partial Y}
$$
你看,是不是输入就需要保存下来了,不保存也可以,再算一次呗,但是计算比保存贵,所以我们需要计算每一层的计算结果
RMSNorm
根据之前的公式,RMSNorm的输入是 BSH, 两个输入都是这个大小,所以是$2BSH$
MHA
这里得看一下MHA的公式,相信大家肯定都会背诵了
$$
MHA(X) = softmax(\frac{QK^T}{\sqrt{d}})V
$$
MHA 里面首先是QKVO 矩阵的输入之后拿去计算,这一部分本身的输入是X,没关系,但是拿去计算都要作为输入的,所以$4BSH$, 这个应该问题不大。然后是$QK^T$ 之后,输入到softmax 的输入,是$BAS^2$, 这里的A 是Attention header 个数,Softmax 之后要乘以$V$, 所以这一步的输入也是$BAS^2$,
所以MHA 里面的Activation 是
$$
4BHS + 2BAS^2
$$
FFN
SwiGLU比传统FFN会带来更多的激活,这里看着swiGLU的公式以及上面的例子,我们就知道,需要保存$W_1X, W_2X$ 的结果用作下一步计算,这里是$2 \times BSd_{ff}$, 然后SiLU的结果会用作下一步点乘的输入,所以这个大小也是$BSd_{ff}$, 一共是$3BSd_{ff} = 8BSH$,这一层的输出是$BSH$,正好交给下一层的Norm了
所以单层Transformer Block的激活是
$$
2BSH + 4BSH + 2BAS^2 + 8BSH = 14BSH + 2BAS^2
$$
Others
最后一层的RMSNorm 是 $BSH$, 读出层的输入是$BSH$, 如果读出层还要做Loss函数,输入就是$BSV$
所以总的Activation就是
$$
Activation = \underbrace{2BSH}{\text{Last RMS + Readout}}
+
\underbrace{BSV}{\text{Loss Function}}
+
\underbrace{L\left(
14BSH
+
2BAS^2
\right)}_{\text{all Transformer layers}}
$$
所以如果要训练,需要的显存大小是
$$
\underbrace{P}_\text{Model} + \underbrace{P}_\text{Gradient} + \underbrace{2P}_\text{AdamW} + Activation
$$
现在我们继续以GPT3 为例,上面两个多余的参数这不就有用了?如果我们就用1个Batch 来炼丹
1 | S = 2048 |
所以如果用的是bf16 ,优化器是FP32,而且优化器还要存一份FP32的P做一次高精度优化,那么需要吃掉的显存是
$$
\underbrace{2P}_\text{Model, BF16} + \underbrace{3\times4P}_\text{optimizer FP32} + \underbrace{2P}_\text{Gradient BF16} + \underbrace{2\times Activation}_\text{BF16} = 2817.73GB
$$
其中我们把Batch size 有关的计算,也就是activation里面的B设置为1,那么这个值就是随着batch size正比例变化的,是207.3GB, 所以一共是2.8TB的显存呢,老黄一张A100是80G,8卡一个机器,所以需要
1 | (16*P + 2*activation) / 1024 ** 3 / (80*8) |
5台机器才能够跑得起来呢,所以pipeline parallel 不可避免了,而当时官方训练的时候B = 1536, 怪不得当时要吃掉一万多张V100呢
也不知道现在V100已经日益落伍的今天,是否能够,旧时王谢堂前燕,飞入寻常百姓家。孕育出了ChatGPT雏形的高贵V100在经历过黑无天日的AI压榨过后,能否以平价售出,在我等垃圾佬Homelab里面安享晚年呢?
计算量 FLOP 计算
众所周知,参数量不等于计算量,所以我们还需要计算 计算量
这里小常识,已知LLM里面大量的操作都是矩阵乘法,对于矩阵 $A \in R^{M \times N}, B \in R^{N \times K}, C = AB$来说,矩阵乘法,对于$C \in R^{M \times K}$, 每一个元素都是A的行和B的列相乘再相加得到的,所以每一个元素都有K个乘法 + (K-1)次加法,所以可以粗略看错2K次计算,所以对于这样的矩阵相乘,计算量就是
$$
2MNK
$$
了解这个我们就来计算一下前向Forward的计算量,记$N = B \times S$
RMSNorm
主要是一个 [N, H] x [H] 的点乘,所以会占用$NH$的计算,不算多,后面会省略掉,Softmax也是,属于GEMV级别的,算存比太低,卡显存了
MHA
QKVO四次变换,所以是 $4 \times 2 \times NH^2 = 8NH^2$,
$QK^T$ 是 [N, H] x [H, N] = [N, N], 计算量是 $2HN^2$, $Score \times V$ 是 [N, N] x [N, H] = [N, H], 也是 $2HN^2$, 结果是$4HN^2$
所以加一块儿是
$$
4HN^2 + 8NH^2
$$
FFN
我们还是用SwiGLU为例,传统的用上述方法也可以计算
两次 [N, H] x [H, d_ff] = [N, d_ff] 的操作,是$2 \times 2NHd_{ff}$, 还有一次映射回来d的操作,是[N, d_ff] x [d_ff, H] = [N, H], 占用$2NHd_{ff}$,一共是
$$
6NHd_{ff} = 6NH \times \frac{8}{3}H = 16NH^2
$$
所以一个Transformer Block 的计算量是
$$
24NH^2 + 4HN^2 = 2N(12H^2 + 2HN)
$$
上文提到了模型单层的参数量
$$
P_0 = 12H^2 + 2H
$$
所以正比一下,对于一个参数量为$P$的模型前向计算量约为 $2NP$,$N = BS$ 所以一个Token就会吃掉大约$2P$的计算量
反向传播
本质是矩阵乘法的反向传播,对于一个线性层,我们需要计算2次梯度,而我们之前的计算 计算量的全是通过矩阵乘法计算的
- W本身的梯度 $\nabla W = X^T \nabla Y$
- X的梯度,用于向前传递$\nabla X = \nabla Y W^T$
所以计算量是前向的2倍
因此可以粗略得到,训练的FLOPs = 3倍前向FLOPs
前向FLOPS 约等于2N倍参数, 如果要训练1个Token,就需要吃掉6P的计算量
感觉这个比推Activate要简单
那么继续按照GPT3的案例,看看要训练多久啊,我们由于不知道一共吃了多少个Batch,所以我们用6P原则粗略计算一下,查询到一共吃了300B的Token,所以说
1 | 300e9 * 6 * P |
那么一张A100的BF16的FLOPS是312 TFLOPS,也就是一张A100我们需要
1 | 300e9 * 6 * P / 312e12 / 3600 / 24 |
1万多天,人生不过短短3万多天, 如果我们8卡一个集群,开9个集群,不考虑通信延迟之类的要多久呢?
1 | 300e9 * 6 * P / (312e12 * 72) / 3600 / 24 |
那么如果我们能用上目前最强的GB200 NVL72 Rack,也就是72个集群一共约180PFLOS嗯造呢?
1 | 300e9 * 6 * P / (180e15) / 3600 / 24 |
也就是说这个速度提升了8倍。A100是2020年发布的,B200 是2024年发布的,也就是说4年的时间,老黄把AI性能是实打实提升了8倍, 也就是说每一年比上一年提升20%,这个速度不知道各位看官觉得是挤牙膏还是如何