LESSON 26 · 卷 大语言模型

KV 缓存、稀疏注意力与 FlashAttention

注意力的账单是 O(n²)——处理一本书,工程师用什么手段把它变快?

第 1 站

注意力的账单:O(n²)

第 21 课结尾我们记了一笔账:注意力要让每个 token 和所有 token 两两比较, 个 token 就有 个格子。当时一带而过,这一课来还这笔账。

1k tokens100万 次比较(一段文字8k tokens6400万 次比较(一篇长文32k tokens10亿 次比较(一本书1M tokens万亿 次比较(一部百科全书序列长度翻倍,计算量翻 4 倍——O(n²) 涨得吓人
图 26-1注意力矩阵的大小是序列长度的平方。32k token(约一本书)就有 10 亿个格子,而且每一层都要重算一遍

把数字摆具体:GPT-4 上下文 128k token,若全用标准注意力,一层的注意力表就有 128000² ≈ 164 亿个格子,乘上约 120 层——一次推理光注意力就得算两万亿次,FFN 还没算进去。 这道平方的账单,是长上下文路上最硬的一堵墙。

n² 个权重,难道个个同等重要?能不能只算关键的那一小撮?
第 2 站

第一条路:稀疏注意力——干脆别算那么多

最直觉的省法:每个 token 别再看全部 n 个位置,只看左右 W 个邻居。 这叫滑动窗口注意力,计算量从 一下降到 (W 固定,相当于随 n 线性增长):

标准注意力(全部 n²)滑动窗口(W=2)全算 → 49 个格子只算对角带 → 31 个(n 越大,省得越多)
图 26-2滑动窗口只算对角带(每个词与左右 W 个邻居),大片空格直接跳过。序列越长,省得越夸张。

可只看邻居有个明显代价:看不到远处——开头那句关键信息,到结尾就盲了。 所以实战中会打几个补丁:留几个「全局 token」(比如开头的指令)让所有位置都能看到它、它也能看到全文; 或者让窗口「跳着看」(膨胀窗口),用同样的格子数覆盖更远。Longformer(2020)、Mistral(2023)就靠这类组合,把上下文撑到几十万 token。

这条路是「近似」,会丢东西

稀疏注意力改变了数学结果——它赌「远处不重要」,赌对了省钱、赌错了漏信息。 可有时候我们就是想要完整、精确的注意力(每个词看所有词),一格不漏, 只是嫌它太慢、太占内存。这就引出了第二条、完全不同的路。

第 3 站

先摘个唾手可得的果子:KV 缓存

在动用聪明算法之前,生成阶段有一笔纯属浪费的重复账,白送的优化。回想第 23 课的自回归:模型一个词一个词地吐, 每吐一个,就把它接到末尾再跑一遍。问题来了——每跑一遍,它都把前面所有词的 K、V 从头又算了一次

可那些词没变啊!「I」的 K、V,在第 2 步、第 3 步、第 100 步都是同一个值。 于是有了 KV 缓存:算过的 K、V 存起来,下一步只算新来的那一个词,其余直接从缓存取:

生成第 5 个词时,前 4 个词的 K/V 早就算过了IK,V 已缓存 ✓loveK,V 已缓存 ✓thisK,V 已缓存 ✓sunnyK,V 已缓存 ✓day只算这一个!没缓存:每步重算全部 → 生成 n 个词总共 O(n²)有缓存:每步只算 1 个 → 每步 O(n),省掉大量重复
图 26-3KV 缓存:已经生成的词,它们的 Key/Value 存进缓存,再不重算;每生成一个新词,只为这一个词算 K/V,然后拿它的 Query 去和「缓存里全部 K/V」做注意力。这是所有大模型推理的标配加速。

代价是拿显存换速度:缓存要一直占着内存,序列越长、占得越多—— 这也是为什么超长对话特别吃显存。后来的 MQA / GQA(第 20 课提过的多头变体)就是为了缩小这个缓存而生的: 让多个头共享同一份 K、V,缓存立刻小一大截。省内存这件事,从注意力一路追到了缓存。

第 4 站

第二条路:FlashAttention——不减计算,只改顺序

这条路最反直觉:一格都不少算,结果和标准注意力分毫不差,却能快好几倍。秘密不在「算」,在「搬」。GPU 有两层存储:

  • HBM(显存):几十 GB,容量大,但读写
  • SRAM(片上缓存):只有几十 MB,但读写快十几倍

标准注意力的真正瓶颈,是那张 的大矩阵:它太大,塞不进 SRAM, 只能写进慢吞吞的 HBM,再读回来做 softmax,再写回去……来回搬运几趟。瓶颈根本不是计算,是搬运。

可这里有个拦路虎:softmax 要先看到一整行所有分数(求最大值、求总和)才能归一化。 要是把行切成小块、一块一块进 SRAM,看第一块时还不知道后面有没有更大的数,怎么算?FlashAttention(2022,斯坦福)的核心绝活,就是一招「在线 softmax」:边读边修。

要对一行分数 [2, 1, 4, 0] 做 softmax,但每次只来两个数:
先来 [2, 1]:记下当前最大 m=2当前分母 ℓ = e2−2+e1−2 = 1+0.37 = 1.37
再来 [4, 0]:发现 4 比旧最大 2 还大 → 把旧账按 e2−4=0.14 缩小一下:1.37×0.14 = 0.19
     新分母 ℓ = 0.19 + e4−4+e0−4 = 0.19 + 1 + 0.02 = 1.21
直接对 [2,1,4,0] 整行算:e−2+e−3+e0+e−4 = 0.14+0.05+1+0.02 = 1.21 ✓ 完全一致

诀窍就在那句「发现更大的数,就把已经算的部分等比缩小一下」。靠一个running 最大值和一个running 分母, softmax 不用一次看到整行,也能流着算、且结果精确。于是注意力可以分块塞进 SRAM 算完就走,那张 n² 大矩阵压根不用落地到 HBM

标准:大矩阵反复进出 HBMQ,K,V (HBM)QKᵀ → 写回 HBMSoftmax → 写回 HBM× V → 输出Flash:分块在 SRAM 算完分块读 Q,K,V → SRAMSRAM 内在线 softmax只把最终结果写回 HBM大矩阵不落地,快 2~4 倍、省显存数倍
图 26-4FlashAttention 的关键:靠在线 softmax,把注意力分块在快速的 SRAM 里算完,那张巨大的 n² 矩阵从不写进慢速 HBM。数学结果和标准注意力一字不差,只是搬运少了——快 2~4 倍,显存省好几倍。

它不改变结果、不丢信息,所以成了无脑就该开的默认项:今天 PyTorch、各大训练框架都内置了它。 稀疏注意力(省算)和 FlashAttention(省搬)还能叠加使用,长上下文这才真正跑得起来。

第 5 站

总结

本课核心 · TAKEAWAY

注意力的 O(n²) 墙有几种攻法:稀疏注意力少算(只看邻居,是近似、会丢远处);KV 缓存不重算(生成时存下旧词的 K/V);FlashAttention 不减计算、只改搬运(在线 softmax 让大矩阵不落显存,结果还精确)。 正是这些工程把 128k+ 的长上下文从妄想变成了日常。

这一课你亲手推导了

  • 账单:注意力 O(n²),128k token 一层就 164 亿格子 × 120 层。
  • 稀疏注意力:滑动窗口 O(n·W) + 全局 token / 膨胀窗口补远处;代价是近似。
  • KV 缓存:生成时缓存旧词 K/V,每步只算新词,从 O(n²) 降到每步 O(n);MQA/GQA 进一步缩缓存。
  • FlashAttention:在线 softmax([2,1,4,0] 流式算出同样的分母 1.21)让大矩阵不落 HBM,快 2~4 倍且精确。
小测验

学习小测验

做完这一课,来检测一下核心知识点。选出你的答案后点击「提交」,即可看到正确选项与讲解。

Q1标准自注意力处理长度为 n 的序列,其计算/内存开销大致是?这是长文本的主要瓶颈。
Q2下面哪一项的目的是「在自回归生成时,避免重复计算已经处理过的 token 的 Key/Value」?
NEXT · 第 27 课

MoE 混合专家架构

GPT-4 据说有 1.8 万亿参数——可每次推理只用其中一小部分?「又大又快」是怎么做到的?

0 人点赞,0 人看过