- 在解决的是什么问题?想让 self-attention 加速
- 为何成功,标志/准是什么?尝试了独特的角度:让它使用 IO-感知的方法,利用好 GPU 层次结构的显存读写
- 在前人基础上的关键创新是什么? 减少 HBM 的读写次数,多用更快的 shared memory。基于 flash attention 还提出了更快的近似算法:block-sparse attention
- 关键结果有哪些?速度快,因此能允许更长的序列在一些NLP的比赛里分数更高
- 有哪些局限性?如何优化?
- 这个工作可能有什么深远的影响?
虽然 transformers 越来越大,深,但是使用更长的上下文依然很难,因为自注意力模块核心的时间和显存复杂度都是序列长度的二次方。所以一个核心问题是让attention更快、显存更高效是否能帮助 transformer 模型克服他们在长序列(long sequences)情况下运行时和显存挑战。目前的 GPT3 被限制在了 2k 的长度。而 FLASH 可以长到8k,最长64k。这样还可以用到高分辨率的VIT上。
有很多近似的attention方法想减少计算和显存开销。包括 sparse-approximation, low-rank approximation,以及他们的组合。尽管他们减少到线性或者接近线性序列长度的计算开销,但是很多并没有明显的加速效果,因此使用不广泛。一个主要原因是他们集中在减少 FLOPS(可能跟加速不直接关联),因此忽视了仿存方面(IO)的开销
本文里,我们认为之前没考虑到的一个方法是:让注意力算法对 IO-感知,仔细审计对不同级别快慢不同显存的读写(比如快的 GPU 上的片上 SRAM和相对慢的 GPU带宽显存,或者 HBM,见下图1)。现代的 GPU 上,计算速度比显存速度要快非常多,transformer里的大部分操作都是仿存速度为瓶颈。IO-aware 算法对类似内存制约的操作上很关键,即当读取和写入 数据在执行时间上占大头的情况:例如数据库 joins,图片处理,数值线性代数(numerical linear algebra)。但是 Python接口的深度学习框架比如 PyTorch 或者 TF 没有暴露这种显存访问的精细接口。
我们提出的 FA,是一种新的 attention 算法,可以用更少的内存访问来完成一样 的attention运算。目标是避免从 HBM 里读取和写入attention矩阵。这需要:
- 计算 softmax 规约时,不要访问整个输入
- backward时,无须fwd过程存储大得到中间attention矩阵
我们用了两个熟知的技术来解决这些挑战:
- 重建了 attention 的计算过程,把输入切分成了 blocks,遍历了好几轮输入的blocks,因此能增量执行softmax 规约(也叫做 tiling)
- 存储了fwd过程里的 softmax 规约的因子,可以在 bwd 过程里在片上快速重计算 attention,这样比标准的从 HBM 里读取 attention 矩阵要快。我们在 CUDA 里实现了 FA 来获得更细粒度的内存访问控制,把所有 attention 操作都融合到一个 GPU kernel(matmul, dropout, softmax, mask, matmul)。尽管因为重计算而导致 FLOPS 增加,但是算法依然更快(在 GPT-2上是 7.6倍),使用更少的显存:相比标准的attention,是输入长度的线性,因此 HBM 访问猛降。
我们分析了 IO 复杂度,证明了它需要 O(N^2d^2M^-1) HBM 访问,d是 head dimension,M 是 SRAM 的大小,与之相比标准 attention 是 O(Nd+N^2)。对于典型的 d 和 M,FA 需要的 HBM 访问是多倍的少于 标准 attention 的(最大9倍,见图2)。而且,我们提供了下界的证明,说明没有其他 attention 算法能渐近地改进在所有SRAM大小上的 HBM 访问次数
我们还展示了 FA 可以作为实现潜在的近似 attention 算法的有用原语,因为克服了他们的访存开销。为了作为概念的验证,我们实现了 block-sparse FA,是比 FA 快2-4倍的稀疏 attention 算法,能扩展到 64k 的序列长度。证明了 block-sparse FA 比 FA 的 IO 复杂度要好,是倍数于稀疏率。我们在第五节讨论了更多扩展(attention在多个GPU上,kernel回归,block-sparse 矩阵)。开源了代码
验证了 FA 在模型训练上的加速效果,提高了模型质量,方法是给更长的上下文建模。也 benchmark 了运行时和显存开销
-
更快的模型训练 FA 在 BERT-large(seq len 512)上比 MLPerf 1.1 里的记录快15%,GPT2(seq length 1K)比 HuggingFace 和 Megatron-LM 快3倍,在 long-range arena (seq length 1K-4K)上快2.4倍
-
更高的模型质量 支持了 16K 的序列长度。Block-sparse FA 支持 64K
-
Benchmarking Attention 比标准实现,在常见序列长度 128~2K上 快到3倍,能扩展到 64K。长度到了 512 之后,FA 比已有的任何 attention 算法要快而且更节省显存。而长度超过 1K 后,一些近似的 attentioin 方法(比如 Linformer)开始变的更快。另外,block-sparse FA 比任何我们已知的近似方法要快
GPU 显存层次 GPU 显存层次(下图1左)。例如 A100 GPU上,有 40-80G的高带宽显存(High bandwidth memory),带宽是 1.5-2.0TB/s,芯片上108个 streaming multiprocessors 里的每个里有 192KB的 SRAM(v100是128K),带宽大约是 19TB/s。可见 SRAM 速度比 HBM 高一个数量级,但是大小上却小很多数量级(40G vs 10M)。由于计算速度相对仿存速度要快很多,因此操作的瓶颈越来越多是显存的访问(HBM)。因此利用更快的 SRAM 就很重要。
性能特性. 根据 op 的计算和仿存的关系,可以分为计算密集型(compute-bound)和仿存密集型(memory-bound)。通常是通过算术密度(arithmetic intensity),即每个字节的显存访问上的算数操作次数
- 计算密集型:有很多算术操作,而仿存很少。典型例子:inner dimension很大的 matrix multiply,有很大channels 的 conv
- 显存密集型:典型例子包括其他的op:elementwise(activation,dropout),和规约(sum, softmax, bn, ln)
Kernel 融合 最常用的加速显存密集型算子的方法是 kernel fusion:如果对输入有多次op运算,可以从 HBM 里加载一次,而不是每次运算都加载一次。编译器可以自动融合很多 elementwise 运算。但是从模型训练上下文里,中间值依然要写入 HBM,因为 BWD 里需要使用,这降低了简单的 kernel 融合的有效性。
有图里,可以看出他计算时:Q、K、V都是在行上切分,即切分N,d这个维度是不切分的,不然就没法在inner loop 里计算出 Sij了
给定输入序列 Q、K、V属于 R^(Nxd),其中N 是序列长度,d是 head的维度,需要计算 attention 的输出 O属于 R^(Nxd):
S = Q*K^T 属于 R^(NxN), P = softmax(S),属于 R^(NxN), O=P*V,属于 R^(Nxd)
其中 softmax 是逐行计算的。
标准的实现把矩阵 S 和 P 写到 HBM 里去,需要花费 O(N^2)的显存。通常 N>>d(比如 GPT2里,N=1024,d=64)。其中的一些算子是仿存密集型的(比如 softmax),大量的 HBM 访问导致速度慢。 这个问题还会被其他在 attention 矩阵上进行的挨个元素的操作,比如应用在 S 上的 mask,应用在 P 上的 dropout 给拖慢。因此有很多尝试是想把多个逐个元素访问的操作融合到一起,比如 fusing masking和softmax (ls 里也是这样的)
在3.2节,展示了标准 attention 实现会有与序列长度 N 成二次方关系的 HBM 访问。我们也比较了两者之间FLOPS和HBM 访问的次数
我们要展示如何用更少的 HBM 读取/写入来计算一样的attention,而不需要存储巨大的中间矩阵来给 bwd 使用。这让算法既显存高效,速度又快。分析IO复杂度,发现我们的方法需要更少的 HBM 访问。
下面专注在 forward 上的分析;附录 B 包含bwd的细节
主要思路是把 Q、K、V 切分成块,从慢的 HBM 加载到快的 SRAM,然后计算对应的 attention 输出。通过给每个block的输出乘上right normalization因子然后再做加法,可以最终得到准确的值。
Tiling: Softmax 和 K 的列是耦合的,所以使用 scaling 分解大的 softmax。为了数值计算稳定,softmax 如下计算: softmax(x) = f(x)/l(x)
这里对softmax 的改进,跟计算m(x)的改进方法类似:
分块计算出来 x1和x2的max,然后就可以再增量计算出他俩整体的max
而FA的FWD 算法里,不仅是m,l和O都是用这种算法 $$ \begin{aligned} m(x^{(1)}) \end{aligned} $$
当第二块计算出来时,就知道这俩整体的情况了 $$ \begin{aligned} \max(x) &=\max(x_1^{(1)},\ldots, x_N^{(1)},x_1^{(2)},\ldots, x_M^{(2)}) \ &=\max(\max(x_1^{(1)},\ldots, x_N^{(1)}), \max(x_1^{(2)},\ldots, x_M^{(2)}))\ &=\max(m(x^{(1)}), m(x^{(2)}))
\end{aligned} $$
通过额外记录一些统计关系(m(x), l(x)),可以每次计算一块的 softmax。因此把输入的 Q、K、V分片为 blocks(Algo 1 line 3),使用额外的统计值(line 10)来计算 softmax(line 10),把结果合并起来(line 12)
重计算: 目标之一是不用给 bwd 存储O(N^2)的中间值,即 S、P。bwd里通常是需要他们来计算相对于 Q、K、V的梯度的。然而,通过存储O和 softmax 里计算所需的 (m, l)统计值,可以很容易根据 SRAM 里的 Q、K、V(是要在bwd时重新load进去的吧?只不过比load S和P要大小小) 来重计算出上述的 S、P。这可以看为类似于选择性左 gradient checkpointing。虽然它一般目的是减少峰值的显存占用,所有实现都是要牺牲速度来换显存。但是对比起来,我们的重计算方案却尽管FLOPS多了,但是bwd过程却加速了,因为减少了 HBM 仿存 (Fig. 2)。bwd的完整描述在附录B
实现细节: Kernel 融合。 Tiling 让我们可以用一个 CUDA kernel实现算法,把 input 从 HBM 里加载进来,执行所有的计算步骤(矩阵乘,softmax,选择性 masking,dropout,matrix multiply),然后把结果写回 HBM (masking和 dropout在附录 B)。这样避免从 HBM 读取输入,写入输出
-
为什么分片的大小,K、V 里面是Bc,是 M/4d ? 大概就是几十到上百行
-
为什么分片时分两组?Q,l、m;K、V是一组,大小为Br,是min(Bc, d)?
-
为什么先循环列,再循环行?
-
这里说 Block size 越大越好,如下图2里中间的图,越大,仿存次数越小。但是看起来blocksize=ceil(M/4d
图2:
左边:GPT-2 medium 上的 FWD+BWD 运行时对比(seq len 1024, head dim 64, 16 heads, batch size 64) 在A100 上。确实 FA 的仿存次数减少了9倍,耗时减少了6倍!
中间:A100 上不同 block size 下, FA 的 FWD 运行耗时和 HBM的访问量(GB? 不是次数,是大小?),seq length 1024,head dim 64,16 heads,bs 64。当 block size 越大,HBM 访问次数更少,但是有一个上限.
右边:seq length 4k 上的 block-sparse FA的加速比例,和稀疏度成正比
渐近地提高所有 SRAM 大小上的 HHBM 仿存,证明在附录 C:
定理2:让 N 是序列长度,d是 head dim,M 是 SRAM 的大小,那么 d 远小于 M,远小于 Nd。标准的算法需要 O(Nd+N^2) HBM 访问,而 fa 需要 O(N^2*d^2*M^-1) 次。而其中 d 远小于M,所以 fa 的访问次数小于 O(N^2)。
对于典型的 d(64-128),M(约为 100KB),d^2 是比M小很多的。
算法1: FA
尽管利用的技术(tiling & recomputation)都是已知的,但是依然有空间(2-4x)来加速 attentioin。需要使用 softmax decomposition。我们认为有两个原因导致虽然FLOPS增加(重计算),但是加速了:
- softmax decomposition with scaling 虽然被很多 ML 算法人员熟知,但没被很多系统研究员知道,虽然 operation fusion/memory IOs 减少是系统/编译器社区的常备,但是对算法工程师而言却不熟悉。在 F 里,即需要 softmax decomposition,也需要 operation fusion 来达到加速和节省显存的效果
- flash attention 里,算法是没变的,所以模型精度不受影响,甚至能提升(因为可以训练更长的序列)
附录 E.5 (NVfuser 1.12, AOT compiler from functorch, and TVM)里添加了自动做op fusion的baseline。Megatron 的比 这些快,而 flash 比他们还快。因为 softmax 被分解为了 local softmax。希望编译器领域的进展可以让未来加速/fusion成为可能。我们也在和编译器研究员一起自动化这些技术
与 Self-Attention Does Not Need O(n^2) Memory 之间的区别?(附录 B.5)。他们聚焦在减少总共的显存需求,利用的通用的 gradient checkpointing 技术,FlashAttention 简化了 backward (B.2, B.4),获得了加速,节省了显存
与 FLAT 区别:它缺少 softmax decompossion 技术 (L151),因此需要在整个行(或者多行)上计算 softmax。因此,flat 需要定制的有更大 sram 的硬件来放整个 key 序列,这个在 GPU 上不实用
要想让 automatic fusion 能把 softmax 也 fuse 了,需要先把 softmax 操作改为 decomposed 版本,再走 fusion 过程。希望未来编译器能支持这个特性。
重计算的开销:很快,因为输入都在 SRAM 里(为啥?),因为是作为 backward kernel 里的一部分。图2左侧展示了 pytorch 和 flash 的实现区别:虽然因为重计算多了13%的 FLOPs,但是因为减少了9.1倍的 IO,导致有 5.7倍的加速 F 对 batchsize 变化不敏感。比如 16,32,64上,都是 2-3x 的比 Megatron-LM 要快
MLPerf 里的最新实现是 vendor 里软硬件组花了6个月的实现。但我们的依然比他们的快
FA的实现需要至少一定数量的 SRAM。在 T4 上测试过。其他非 GPU 加速器有比 GPU 更多的 SRAM(比如 TPUv4 有128M SRAM,Graphcore 有 1GB SRAM,而 A100里只有19MB)
- 他们在跑 GPT 等时,用的pytorch?代码有吗?
- MB15里能用吗?比如最大长度1000*1000 = 1000K
- Long range arena: A benchmark for efficient transformers
- Efficient transformers: A survey.
- Data movement is all you need: A case study on optimizing transformers
- Memory hierarchy design. Computer Architecture: A Quantitative Approach(2003)
- Roofline: an insightful visual performance model for multicore architectures(2009). E.7 里有
- E.5 里有与 TVM 的对比
- openreview
- 在 PyTorch 里的实现: 没有用 triton,用的cutblass
- GPT or VIT 里的 flashattention 例子