残差连接与层归一化
96 层 Transformer——梯度怎么从第 96 层传回第 1 层而不消失?
96 层网络,梯度能传回第一层吗?
GPT-3 有 96 层 Transformer Block。训练时损失在最后一层算出,梯度要一路反向传播——从第 96 层传回第 1 层。
第 13 课讲过,梯度每穿一层,就要乘一次那层的导数。假设每层导数平均是 0.9:
第 1 层收到的梯度只有原来的万分之零点六——几乎是零。它的参数得不到有效更新信号,等于白占内存。 反过来,要是每层导数平均 1.1,,梯度爆炸、数值溢出,训练当场崩盘。 连乘这件事,要么把梯度碾成零,要么把它吹上天,极难刚好停在 1 附近。
这个敌人你已经撞见过两回:第 13 课反向传播,梯度沿链式法则一层层连乘;第 17 课的 RNN, 把同一个矩阵沿「时间」连乘几十步,于是记忆消退。这一次只是换了方向——沿「深度」连乘 96 层。 同一道数学难题第三次找上门,所以解法的灵魂,也和当年给 LSTM 修「记忆高速公路」如出一辙。
残差连接:加一条直通路
残差连接的做法简单粗暴:把这一层的输入,原样加到它的输出上。
会变——但不会变坏。关键在于 是训练出来的,不是写死的。 加上这条捷径后,这一层要学的目标也跟着变了:原来它得亲手算出完整的目标输出 , 现在它只需算出「目标与输入之间的差」, 最后 又自动拼回 。 训练时网络早把这条「+x」算在内,会主动调整 去配合它——最终该输出什么还是什么,只是「这一层负责算哪部分」换了个分工。这正是「残差」(residual = 余项、差值)这个名字的由来。
更妙的是,这个分工让「按需微调」变得极其廉价。设想某一层最该做的事其实是「别添乱,把输入原样传下去」: 普通层得逼一整组权重恰好凑出一个恒等映射,意外地难;残差层只要让 就行——「什么都不学」反而成了最容易达成的默认,需要改时再让 吐出一点小修正。 于是深层网络可以「能不动就不动,要动只动一点」,叠到上百层也不怕被中间某层搅乱。
(唯一的前提:相加要求 与 形状相同。Transformer 里每个子层的输入输出都保持同样的 维,所以这一加天然成立。)
关键全在求导。对 求梯度:
那个孤零零的 「+1」就是命根子。哪怕 小到 0.01, 这一层的总梯度也是 1.01——稳稳压在 1 附近。于是 96 层连乘不再是 , 而是接近 这种温和的数:既没碾成零,也没吹爆。 96 层残差网络,相当于给梯度铺了一条可以一路直通回第 1 层的高速路。
回头看第 21 课——搭那台「最简 Transformer」时,我们说过真实的块里,注意力和 FFN 后面还各跟着一个「Add & Norm」稳定器,当时一笔带过。 那个「Add」,就是这里的 ——它一直在偷偷给梯度修高速路。并排的「Norm」,是下一站的主角。
LayerNorm:把数值拉回正常量级
看完上一站,你心里大概已经打鼓了:那条「只进不减」的高速路,每过一块就往上加一截 , 这么一路加下去,数值会不会越滚越大?——确实会。但先别急着把账全算在残差头上: 「深层网络的激活值尺度会到处乱跑、得不时校准」本来就是个通病,早在没有残差的年代就存在; 残差只是让它以「单调膨胀」这种特别显眼的方式冒了出来。所以这一站的主角——归一化, 并不是专门给残差打的补丁,而是深层网络通用的「稳压器」。
先说清楚这一站要收拾的隐患:激活值的量级会慢慢漂——某些维度越加越大、另一些越来越小, 十几层后,同一条向量里可能一个分量是 10000、另一个是 0.001。这种悬殊会让后面层的权重彻底失准。
「漂」不是抽象的担心,拿数字走一遍就看得见。不过在走之前,得先把两个记号交代清楚,免得看着发懵:
- 记号 ① :我们给每一块编个号 (念作 L,就是「第几块」), 表示数据流过第 块之后的那条向量。 于是上一站残差的「累加」写成式子就是 , 读作「下一块的向量 = 这一块的向量 + 这一块算出的增量 」。
- 记号 ② :一条向量到底「多大」,用它的长度来量,记作 ,叫范数—— 其实就是向量的长度(各分量平方求和再开根号)。
有了这两个记号就好说了:假设向量刚进网络时 ≈ 1,之后每过一块就被 + 顶大一点。 把几个残差块首尾叠起来、逐块标出输入输出,量级滚雪球的过程就一目了然:
那……量级一路滚大,不拉回来会怎样?麻烦在于:后面每一层的权重,都是在「输入大约是某个量级」的前提下练出来的。 一旦输入层层暴涨,后面的层就像拿着厘米刻度的尺子去量公里——越往后越乱套。具体会撞上两堵墙:
其一,数值精度撑不住。大模型常用 16 位浮点做训练,位数有限:量级一旦失控,要么直接溢出成 inf / NaN、训练当场崩盘; 要么大数把小数「吃掉」——10000 和 0.001 相加,那个 0.001 直接被舍没了,信息凭空蒸发。
其二,该敏感的地方变迟钝。还记得第 11 课的 Softmax 吗?它对输入的大小极其敏感:喂进去的数一旦过大, 输出就会一边倒地挤成「非 0 即 1」,注意力只死盯一个词、梯度也几乎归零,网络再也学不动。
正是为了堵上这两堵墙——LayerNorm(层归一化)登场了。它的活,就是在每个子层后把向量拍回标准量级。它对每一个词向量单独做三步:
拿一条 4 维向量 手算一遍:
方差 = [(8−4)² + (2−4)² + (2−4)² + (4−4)²]/4 = (16+4+4+0)/4 = 6 → σ = √6 ≈ 2.45
归一化 x̂ = [(8−4)/2.45, (2−4)/2.45, (2−4)/2.45, 0] ≈ [1.63, −0.82, −0.82, 0]
一条原本 [8,2,2,4]、量级乱七八糟的向量,被拉成了一条均值 0、标准差 1 的「标准」向量。 第 ③ 步的 是可学习参数:归一化是死规矩,但网络可以靠 γ、β 把分布再「拧」回它真正想要的范围—— 既享受稳定,又不丢灵活。
一个不起眼却要命的细节:Norm 放前还是放后
残差 + LayerNorm 都有了,但它俩谁先谁后,竟决定了深层模型训不训得起来。这是 2019 年后才被业界看清的一个坑:
写成式子,差别就一目了然——Pre-LN 的那个 完全没被 Norm 碰过,高速路纯净:
还有个更省的变体叫 RMSNorm(LLaMA 等在用):它干脆连均值都不减, 只把向量除以它自己的「均方根」来定量级()。 少算一步、参数更少,效果却几乎一样——又一次印证了这一卷反复出现的审美:能简则简,简到不能再简为止。
总结
残差连接 y = f(x) + x:那个「+1」给梯度留一条不衰减的高速路,96 层连乘也不消失。LayerNorm:把每条向量拉回均值 0、标准差 1,挡住数值漂移。 再加一个讲究——Pre-LN 让残差主路保持纯净,才有了今天能稳稳堆到上百层的深层模型。
这一课你亲手推导了
- 深度连乘的诅咒:0.9⁹⁶≈0.00006(消失)或 1.1⁹⁶≈12000(爆炸),极难停在 1。
- 残差连接:∂y/∂x = ∂f/∂x + 1,那个「+1」把每层梯度钉在 1 附近。
- LayerNorm:[8,2,2,4] → 减均值 4、除 σ≈2.45 → [1.63,−0.82,−0.82,0],再乘 γ 加 β。
- 放前 vs 放后:Pre-LN(x + f(LN(x)))让残差主路畅通,比 Post-LN 稳得多;RMSNorm 更省。
学习小测验
做完这一课,来检测一下核心知识点。选出你的答案后点击「提交」,即可看到正确选项与讲解。
预训练 · 监督微调 · 强化学习
骨架、零件、稳定器都齐了——可一个只会「预测下一个词」的模型,怎么变成会回答、会帮忙的 ChatGPT?