LESSON 24 · 卷 大语言模型

残差连接与层归一化

96 层 Transformer——梯度怎么从第 96 层传回第 1 层而不消失?

第 1 站

96 层网络,梯度能传回第一层吗?

GPT-3 有 96 层 Transformer Block。训练时损失在最后一层算出,梯度要一路反向传播——从第 96 层传回第 1 层。

第 13 课讲过,梯度每穿一层,就要乘一次那层的导数。假设每层导数平均是 0.9:

经过 96 层:

第 1 层收到的梯度只有原来的万分之零点六——几乎是零。它的参数得不到有效更新信号,等于白占内存。 反过来,要是每层导数平均 1.1,,梯度爆炸、数值溢出,训练当场崩盘。 连乘这件事,要么把梯度碾成零,要么把它吹上天,极难刚好停在 1 附近

这个敌人你已经撞见过两回:第 13 课反向传播,梯度沿链式法则一层层连乘;第 17 课的 RNN, 把同一个矩阵沿「时间」连乘几十步,于是记忆消退。这一次只是换了方向——沿「深度」连乘 96 层。 同一道数学难题第三次找上门,所以解法的灵魂,也和当年给 LSTM 修「记忆高速公路」如出一辙。

96 次连乘把梯度逼到极端。有没有办法让梯度「绕过」一些层,少连乘几次?
第 2 站

残差连接:加一条直通路

残差连接的做法简单粗暴:把这一层的输入,原样加到它的输出上

普通层 
残差层 
等一下:把 x 加回去,输出不就被改了?

会变——但不会变坏。关键在于 训练出来的,不是写死的。 加上这条捷径后,这一层要学的目标也跟着变了:原来它得亲手算出完整的目标输出 , 现在它只需算出「目标与输入之间的, 最后 又自动拼回 。 训练时网络早把这条「+x」算在内,会主动调整 去配合它——最终该输出什么还是什么,只是「这一层负责算哪部分」换了个分工。这正是「残差」(residual = 余项、差值)这个名字的由来。

更妙的是,这个分工让「按需微调」变得极其廉价。设想某一层最该做的事其实是「别添乱,把输入原样传下去」: 普通层得逼一整组权重恰好凑出一个恒等映射,意外地难;残差层只要让 就行——「什么都不学」反而成了最容易达成的默认,需要改时再让 吐出一点小修正。 于是深层网络可以「能不动就不动,要动只动一点」,叠到上百层也不怕被中间某层搅乱。

(唯一的前提:相加要求 形状相同。Transformer 里每个子层的输入输出都保持同样的 维,所以这一加天然成立。)

输入 xf(x)(注意力 or FFN)梯度:∂f/∂x,可能很小捷径 x梯度:1f(x) + x输出 y
图 24-1残差连接给梯度开了一条「高速公路」(绿色捷径)。反向传播时梯度沿两条路回传:一条穿过 f(x)(可能很小),一条走捷径(梯度恒为 1,不衰减)。

关键全在求导。对 求梯度:

那个孤零零的 「+1」就是命根子。哪怕 小到 0.01, 这一层的总梯度也是 1.01——稳稳压在 1 附近。于是 96 层连乘不再是 , 而是接近 这种温和的数:既没碾成零,也没吹爆。 96 层残差网络,相当于给梯度铺了一条可以一路直通回第 1 层的高速路。

回头看第 21 课——搭那台「最简 Transformer」时,我们说过真实的块里,注意力和 FFN 后面还各跟着一个「Add & Norm」稳定器,当时一笔带过。 那个「Add」,就是这里的 ——它一直在偷偷给梯度修高速路。并排的「Norm」,是下一站的主角。

第 3 站

LayerNorm:把数值拉回正常量级

看完上一站,你心里大概已经打鼓了:那条「只进不减」的高速路,每过一块就往上一截 , 这么一路加下去,数值会不会越滚越大?——确实会。但先别急着把账全算在残差头上: 「深层网络的激活值尺度会到处乱跑、得不时校准」本来就是个通病,早在没有残差的年代就存在; 残差只是让它以「单调膨胀」这种特别显眼的方式冒了出来。所以这一站的主角——归一化, 并不是专门给残差打的补丁,而是深层网络通用的「稳压器」。

先说清楚这一站要收拾的隐患:激活值的量级会慢慢漂——某些维度越加越大、另一些越来越小, 十几层后,同一条向量里可能一个分量是 10000、另一个是 0.001。这种悬殊会让后面层的权重彻底失准。

「漂」不是抽象的担心,拿数字走一遍就看得见。不过在走之前,得先把两个记号交代清楚,免得看着发懵:

  • 记号 ① :我们给每一块编个号 (念作 L,就是「第几块」), 表示数据流过第 块之后的那条向量。 于是上一站残差的「累加」写成式子就是 , 读作「下一块的向量 = 这一块的向量 + 这一块算出的增量 」。
  • 记号 ② :一条向量到底「多大」,用它的长度来量,记作 ,叫范数—— 其实就是向量的长度(各分量平方求和再开根号)。

有了这两个记号就好说了:假设向量刚进网络时 ≈ 1,之后每过一块就被 + 顶大一点。 把几个残差块首尾叠起来、逐块标出输入输出,量级滚雪球的过程就一目了然:

残差高速路 x只进不减 ↓读入 x1 块 · 注意力 f₁算出一截增量 f+f读入 x2 块 · FFN f₂算出一截增量 f+f读入 x3 块 · 注意力 f₃算出一截增量 f+f读入 x4 块 · FFN f₄算出一截增量 f+f‖x‖≈1.0‖x‖≈1.7‖x‖≈2.9‖x‖≈4.4‖x‖≈6.6‖x‖≈28…到第 12 块,已近 30 倍← 入口输入,量级正常← 上一块的输出 = 下一块的输入
图 24-2把多个残差块首尾叠起来看:每块只把自己算出的增量 加到同一条「高速路」 上,上一块的输出就是下一块的输入(节点上的 ‖x‖)。高速路只进不减,量级一块接一块越滚越大——示意里 4 块就到 6.6 倍,到第 12 块近 30 倍(数值为示意,真实模型残差流范数确实随深度近乎线性增长)。这就是为什么必须有人在每块结束时把量级拉回 ≈1,那个人就是下面的 LayerNorm。

那……量级一路滚大,不拉回来会怎样?麻烦在于:后面每一层的权重,都是在「输入大约是某个量级」的前提下练出来的。 一旦输入层层暴涨,后面的层就像拿着厘米刻度的尺子去量公里——越往后越乱套。具体会撞上两堵墙:

其一,数值精度撑不住。大模型常用 16 位浮点做训练,位数有限:量级一旦失控,要么直接溢出成 inf / NaN、训练当场崩盘; 要么大数把小数「吃掉」——10000 和 0.001 相加,那个 0.001 直接被舍没了,信息凭空蒸发。

其二,该敏感的地方变迟钝。还记得第 11 课的 Softmax 吗?它对输入的大小极其敏感:喂进去的数一旦过大, 输出就会一边倒地挤成「非 0 即 1」,注意力只死盯一个词、梯度也几乎归零,网络再也学不动。

正是为了堵上这两堵墙——LayerNorm(层归一化)登场了。它的活,就是在每个子层后把向量拍回标准量级。它对每一个词向量单独做三步:

① 求这条向量的均值和标准差 
② 归一化到均值 0、标准差 1 
③ 再缩放平移(参数可学) 

拿一条 4 维向量 手算一遍:

均值 μ = (8+2+2+4)/4 = 4
方差 = [(8−4)² + (2−4)² + (2−4)² + (4−4)²]/4 = (16+4+4+0)/4 = 6 → σ = √6 ≈ 2.45
归一化 x̂ = [(8−4)/2.45, (2−4)/2.45, (2−4)/2.45, 0] ≈ [1.63, −0.82, −0.82, 0]

一条原本 [8,2,2,4]、量级乱七八糟的向量,被拉成了一条均值 0、标准差 1 的「标准」向量。 第 ③ 步的 可学习参数:归一化是死规矩,但网络可以靠 γ、β 把分布再「拧」回它真正想要的范围—— 既享受稳定,又不丢灵活。

归一化前(参差不齐)9k30.00212k0.080k归一化后(整齐)-1.2-0.50.10.81.2-0.3
图 24-3LayerNorm 把一条向量从「量级乱飞」拉到「均值 0、标准差 1」。统一了量级,后面层的权重才能稳定工作,不被某个异常大的值带偏。
第 4 站

一个不起眼却要命的细节:Norm 放前还是放后

残差 + LayerNorm 都有了,但它俩谁先谁后,竟决定了深层模型训不训得起来。这是 2019 年后才被业界看清的一个坑:

Post-LN(原始 2017)残差高速路子层 f(x)Add(+x)LayerNormNorm 卡在高速路上 → 不稳Pre-LN(现代 GPT)残差高速路(畅通)LayerNorm子层 f(·)Add(+x)Norm 移到支路上 → 主路畅通
图 24-4Post-LN 把 LayerNorm 压在残差主路之后,等于在高速路上设了收费站,深层时梯度仍不稳,得靠脆弱的「学习率预热」哄着训。Pre-LN 把 Norm 挪到子层之前的支路上,残差主路全程畅通——这才是 GPT 能稳稳堆到几十上百层的关键。

写成式子,差别就一目了然——Pre-LN 的那个 完全没被 Norm 碰过,高速路纯净:

Post-LN  ← x 被 LN 二次加工,主路受扰
Pre-LN  ← x 原样直通,只在支路上 Norm

还有个更省的变体叫 RMSNorm(LLaMA 等在用):它干脆连均值都不减, 只把向量除以它自己的「均方根」来定量级()。 少算一步、参数更少,效果却几乎一样——又一次印证了这一卷反复出现的审美:能简则简,简到不能再简为止。

第 5 站

总结

本课核心 · TAKEAWAY

残差连接 y = f(x) + x:那个「+1」给梯度留一条不衰减的高速路,96 层连乘也不消失。LayerNorm:把每条向量拉回均值 0、标准差 1,挡住数值漂移。 再加一个讲究——Pre-LN 让残差主路保持纯净,才有了今天能稳稳堆到上百层的深层模型。

这一课你亲手推导了

  • 深度连乘的诅咒:0.9⁹⁶≈0.00006(消失)或 1.1⁹⁶≈12000(爆炸),极难停在 1。
  • 残差连接:∂y/∂x = ∂f/∂x + 1,那个「+1」把每层梯度钉在 1 附近。
  • LayerNorm:[8,2,2,4] → 减均值 4、除 σ≈2.45 → [1.63,−0.82,−0.82,0],再乘 γ 加 β。
  • 放前 vs 放后:Pre-LN(x + f(LN(x)))让残差主路畅通,比 Post-LN 稳得多;RMSNorm 更省。
小测验

学习小测验

做完这一课,来检测一下核心知识点。选出你的答案后点击「提交」,即可看到正确选项与讲解。

Q1在 96 层这样的超深 Transformer 里,「残差连接(residual/skip connection)」最重要的作用是?
Q2层归一化(Layer Normalization)在 Transformer 中主要起什么作用?
NEXT · 第 25 课

预训练 · 监督微调 · 强化学习

骨架、零件、稳定器都齐了——可一个只会「预测下一个词」的模型,怎么变成会回答、会帮忙的 ChatGPT?

0 人点赞,0 人看过