KV 缓存、稀疏注意力与 FlashAttention
注意力的账单是 O(n²)——处理一本书,工程师用什么手段把它变快?
注意力的账单:O(n²)
第 21 课结尾我们记了一笔账:注意力要让每个 token 和所有 token 两两比较, 个 token 就有 个格子。当时一带而过,这一课来还这笔账。
把数字摆具体:GPT-4 上下文 128k token,若全用标准注意力,一层的注意力表就有 128000² ≈ 164 亿个格子,乘上约 120 层——一次推理光注意力就得算两万亿次,FFN 还没算进去。 这道平方的账单,是长上下文路上最硬的一堵墙。
第一条路:稀疏注意力——干脆别算那么多
最直觉的省法:每个 token 别再看全部 n 个位置,只看左右 W 个邻居。 这叫滑动窗口注意力,计算量从 一下降到 (W 固定,相当于随 n 线性增长):
可只看邻居有个明显代价:看不到远处——开头那句关键信息,到结尾就盲了。 所以实战中会打几个补丁:留几个「全局 token」(比如开头的指令)让所有位置都能看到它、它也能看到全文; 或者让窗口「跳着看」(膨胀窗口),用同样的格子数覆盖更远。Longformer(2020)、Mistral(2023)就靠这类组合,把上下文撑到几十万 token。
稀疏注意力改变了数学结果——它赌「远处不重要」,赌对了省钱、赌错了漏信息。 可有时候我们就是想要完整、精确的注意力(每个词看所有词),一格不漏, 只是嫌它太慢、太占内存。这就引出了第二条、完全不同的路。
先摘个唾手可得的果子:KV 缓存
在动用聪明算法之前,生成阶段有一笔纯属浪费的重复账,白送的优化。回想第 23 课的自回归:模型一个词一个词地吐, 每吐一个,就把它接到末尾再跑一遍。问题来了——每跑一遍,它都把前面所有词的 K、V 从头又算了一次。
可那些词没变啊!「I」的 K、V,在第 2 步、第 3 步、第 100 步都是同一个值。 于是有了 KV 缓存:算过的 K、V 存起来,下一步只算新来的那一个词,其余直接从缓存取:
代价是拿显存换速度:缓存要一直占着内存,序列越长、占得越多—— 这也是为什么超长对话特别吃显存。后来的 MQA / GQA(第 20 课提过的多头变体)就是为了缩小这个缓存而生的: 让多个头共享同一份 K、V,缓存立刻小一大截。省内存这件事,从注意力一路追到了缓存。
第二条路:FlashAttention——不减计算,只改顺序
这条路最反直觉:一格都不少算,结果和标准注意力分毫不差,却能快好几倍。秘密不在「算」,在「搬」。GPU 有两层存储:
- HBM(显存):几十 GB,容量大,但读写慢;
- SRAM(片上缓存):只有几十 MB,但读写快十几倍。
标准注意力的真正瓶颈,是那张 的大矩阵:它太大,塞不进 SRAM, 只能写进慢吞吞的 HBM,再读回来做 softmax,再写回去……来回搬运几趟。瓶颈根本不是计算,是搬运。
可这里有个拦路虎:softmax 要先看到一整行所有分数(求最大值、求总和)才能归一化。 要是把行切成小块、一块一块进 SRAM,看第一块时还不知道后面有没有更大的数,怎么算?FlashAttention(2022,斯坦福)的核心绝活,就是一招「在线 softmax」:边读边修。
先来 [2, 1]:记下当前最大 m=2,当前分母 ℓ = e2−2+e1−2 = 1+0.37 = 1.37
再来 [4, 0]:发现 4 比旧最大 2 还大 → 把旧账按 e2−4=0.14 缩小一下:1.37×0.14 = 0.19
新分母 ℓ = 0.19 + e4−4+e0−4 = 0.19 + 1 + 0.02 = 1.21
直接对 [2,1,4,0] 整行算:e−2+e−3+e0+e−4 = 0.14+0.05+1+0.02 = 1.21 ✓ 完全一致
诀窍就在那句「发现更大的数,就把已经算的部分等比缩小一下」。靠一个running 最大值和一个running 分母, softmax 不用一次看到整行,也能流着算、且结果精确。于是注意力可以分块塞进 SRAM 算完就走,那张 n² 大矩阵压根不用落地到 HBM:
它不改变结果、不丢信息,所以成了无脑就该开的默认项:今天 PyTorch、各大训练框架都内置了它。 稀疏注意力(省算)和 FlashAttention(省搬)还能叠加使用,长上下文这才真正跑得起来。
总结
注意力的 O(n²) 墙有几种攻法:稀疏注意力少算(只看邻居,是近似、会丢远处);KV 缓存不重算(生成时存下旧词的 K/V);FlashAttention 不减计算、只改搬运(在线 softmax 让大矩阵不落显存,结果还精确)。 正是这些工程把 128k+ 的长上下文从妄想变成了日常。
这一课你亲手推导了
- 账单:注意力 O(n²),128k token 一层就 164 亿格子 × 120 层。
- 稀疏注意力:滑动窗口 O(n·W) + 全局 token / 膨胀窗口补远处;代价是近似。
- KV 缓存:生成时缓存旧词 K/V,每步只算新词,从 O(n²) 降到每步 O(n);MQA/GQA 进一步缩缓存。
- FlashAttention:在线 softmax([2,1,4,0] 流式算出同样的分母 1.21)让大矩阵不落 HBM,快 2~4 倍且精确。
学习小测验
做完这一课,来检测一下核心知识点。选出你的答案后点击「提交」,即可看到正确选项与讲解。
MoE 混合专家架构
GPT-4 据说有 1.8 万亿参数——可每次推理只用其中一小部分?「又大又快」是怎么做到的?