一、总论
虽然我已经在之前的文章中讨论过一遍张量是如何求导的了,并且附上了详细的数学推导。但是我觉得那次的讨论还是有些过于偏向于数学的严谨,而忽略了实际使用中的直观。
所以我打算再推一遍,省略一些数学细节,但是更加注重实际的使用,包括对于计算和内存开销的估算,矩阵的形状等。
二、规律
2.1 LOSS 是标量
无论 LLM Forward 最终是生成一个 token,还是一段 output 序列,还是一个 batch 的 output 序列组,最终的 loss 都是一个标量。
这点似乎还是有些反直觉的,这是因为在很多科普文章中,为了简化描述,往往只会用一个 token 来举例子,而实际上,训练往往是以 batch 的形式进行,并且与不是只生成一个 token,而是一串 token。
在实际生产中,大模型 forward 的结果其实一个形状为 $[B, S, V]$ 的三维张量,用于表示概率,其中有:
- $B$:Batch Size
- $S$:Sequence Length
- $V$:Vocabulary Size
对于每个 token 来说,它生成的是一个概率分布,形状是与词表大小相同的一维向量 ,它会与一个标量的 label(可以理解为“正确答案”)去计算交叉熵,也就是看一下这个概率分布合理与否。最终会得到一个标量的 loss 。
那么此时,我们会得到一个 loss 矩阵 $L$ ,形状是 $[B, S]$ 。问题在于,我们会直接拿着这个二维张量去进行 backward 吗?并不会,我们会对 $L$ 进行 reduce,最终得到一个标量 $l$:
至于为什么不直接用二维张量去 backward,我觉得是出于计算的简便性的考虑。我们都知道雅可比矩阵(Jacobian Matrix)是一个二维张量,而其实因变量和自变量都是一维张量。之所以会发生“升维”,是因为我们要描述“每个自变量的分量”对于“每个因变量的分量”的影响,所以维度就升高了。
如果我们直接使用 $[B, S]$ 的 $L$ 进行 backward,显然梯度的形状应该会变成高维张量,这显然不利于我们进行梯度下降。而且最关键的是,就算我们能忍受高维张量,到最终这些高维张量还是要 reduce 后再更新模型的。
也就是即使我们维护了“权重张量对于不同输出的影响”这个张量,最后我们还是要把不同输出对应的影响全加起来,然后再更新权重,相当于没变化。
总结一下,我们将上面的标量 loss 记录为 $l$ 。
2.2 形状相同
梯度的形状是与原张量的形状是相同的。
这个规律其实和前面的规律息息相关。
然后我们需要理解一下在 backward 中提到的“梯度下降法”的具体含义。梯度最重要的,是确定“因变量”和“自变量”。自变量很好理解,我们需要计算哪个张量的梯度,哪个张量就作为自变量。而因变量就比较迷惑了,先说结论,它一直是 $l$ 。
那么为什么说因变量比较迷惑呢?这是因为 backward 中充分应用了链式法则,在链式法则中,会引入很多的“临时梯度”,这些临时梯度的因变量不再是 $l$ 了。当这些“临时梯度”与“最终梯度”一起出现在公式中的时候,就容易让人感到迷惑了。
那么考虑一个在 LLM 中出现的张量 $X$ ,无论 $X$ 在模型的哪个位置(是最后一层,还是第一层),发挥什么作用(是 FFN 的权重,还是 Attention 的权重,还是激活值,还是偏移值),形状如何(是一维向量,还是二维矩阵,还是考虑 batch 的高维张量),它其中的每个分量都会影响 $l$ 。
所以因此我们得到的梯度 $\frac{\partial l}{\partial X}$ 的形状就一定和 $X$ 保持相同,因为它就表示了 $X$ 中的每个分量对于 $l$ 的影响。
为了强调这个规律,我们引入了新的标记:
其中 $X_G$ 与 $X$ 的形状完全相同。
当然梯度下降法也就很好表示了,因为形状相同(这么看形状也必须相同),说白了就是:
2.3 矩阵乘法很简单
对于 Forward 过程:
无论 $A, B$ 的形状是什么,有 backward 过程:
也就是说,虽然理解并推导链式法则 backward 是一件很困难的事情,但是最终的结论是非常简单且 general 的。当我们拿到一个因变量的梯度的时候,自变量的梯度计算会变得很容易记忆。
如果有两个自变量,那么某个自变量的梯度就是因变量梯度与另一个自变量的乘积(当然可能还需要转置)。
那难道 LLM 中就都是 $Y = AB$ 这种简单形式吗?难道 Attention 和 FFN 这种复杂的网络结构,也能用矩阵乘法表示吗?还这能,这是因为:
- 我们并不限制 $A, B$ 的形状,也就是无论他们是一维的,还是二维的,上面的式子都成立。
- 虽然 Attention 这种结构乍一看很复杂,但是它都可以被拆成很多个矩阵乘法,比如说 $QK$ 和 $PV$ 等。
- 确实激活函数或者 norm 函数(包括 softmax)无法用这个模式硬套,但是用这种 vec2vec 的梯度下降,本身也不是计算的大头。
三、应用
在这一章里面,我们用上面介绍的规律来实战一下,其实主要就是矩阵乘法梯度的规律。
3.1 FFN
可以被理解成两个线性映射层:
我们已经拿到了 $Y_G$ ,带入上面的规律可知:
在计算 $W_G$ 的时候,需要使用到激活值 $X$,这就要求在 forward 的时候,要将这个临时的激活值进行保留。
我们计算 $W_G$ 是出于更新 $W$ 的目的,那么我们为什么要计算 $X_G$,这是因为 $X$ 此时是自变量,而它同时也是因变量,所以需要计算它保证链式传播。
3.2 Attention
Attention 看起来很复杂,但是实际上经过拆解,并不难。在 forward 过程中,有:
经过这么一拆解,就算不算,也知道 backward 可以表示成一个很简单的形式了。
我们已经拿到了 $O_G$ ,然后有:
同时有:
这两个计算的开销是非常大的,因为我们要保留形状为 $[n, n]$ 的 $P, P_G$ ($n$ 是序列长度),这都是非常高昂的开销。
我们不考虑 $softmax$ 的梯度,但是总之 $S_G$ 的形状也是 $[n, n]$ 。
当我们有了 $S_G$ 后,我们就可以推算 $Q_G, K_G$ 了,有:
FlashAttention 之所以这么厉害,不止是因为它在 forward 的卓越贡献,能够避免 $[n, n]$ 矩阵,在 backward 中,它利用保存在 SRAM 里的 Block 形式的 $Q, K, V$ 重新算一遍 Forward,当场算出局部的 $P$,随即算出梯度,直接加和,同样不需要保存 $[n, n]$ 矩阵。