Transformer 模型的时候处理速度较慢,且会占用大量的显存。自注意力的时间和内存复杂度是与序列长度的平方成正比。在 GPU 上,计算速度已超过内存速度,并且 Transformers 中的大多数操作都受到内存访问的瓶颈。因此,内存访问模式的优化是加速 Transformer 模型的关键。Flash Attention 是一种高效的自注意力实现,它通过将内存访问模式与计算结合起来,减少了内存带宽的使用,从而提高了性能。
在这个系列中,我们将介绍 Flash Attention 系列的原理和实现。
老规矩,我们还是先回顾一下 GPU 的层次结构。GPU 内存层次结构由不同大小和速度的多种形式的内存组成,较小的内存速度较快。例如,A100 GPU 具有 40-80GB 的高带宽内存(HBM),带宽为 1.5-2.0TB/s,并且每个 108 个流处理器有 192KB 的片上 SRAM,其带宽估计约为 19TB/s [44, 45]。片上 SRAM 的速度比 HBM 快一个数量级,但其大小小了多个数量级。

picture 0
给定输入序列 Q, K, V ∈ ℝN × d,其中 N 是序列长度,而 d 是头部维度,我们希望计算注意力输出 O ∈ ℝN × d:
S = QK⊤ ∈ ℝN × N, P = softmax (S) ∈ ℝN × N, O = PV ∈ ℝN × d,

picture 4
在标准的注意力机制实现中,矩阵 S 和 P 需要被显式地存储在高速但容量有限的高带宽内存(HBM)中。这种存储方式带来了 O(N2) 的内存开销,这在处理大规模输入时尤其值得关注。
以一个具体实例来看,在 GPT-2 模型中,序列长度 N 为 1024,而每个特征的维度 d 仅为 64,即 N ≫ d。由于注意力机制的核心操作(如 softmax 函数)大多受限于内存访问速度,对 HBM 的高频访问不仅增加了内存带宽压力,还显著延长了计算的整体墙钟时间(wall-clock time),从而降低了模型的运行效率。
FlashAttention 的核心思想可以用两个关键词来概括:分块计算 和 动态重计算。这两种技术的结合,使得注意力机制在保持高效的同时,显著减少了内存占用。
传统的 Softmax 需要一次性加载整个输入数据,才能计算全局的最大值和归一化系数。而 FlashAttention 采用了 增量式计算 的方式,将输入数据分成小块,依次加载到 GPU 的片上缓存(SRAM)中。
我们首先定义一些变量方便后续的讨论:
| 变量 | 尺寸(shape) | 说明 |
|---|---|---|
| Q, K, V | N × d | 输入矩阵 |
| Qi | Br × d | Q 的第 i 个行分块 |
| Kj, Vj | Bc × d | K, V 的第 j 个行分块 |
| Sij | Br × Bc | 局部注意力分数矩阵 |
| m̃ij | Br | 局部行最大值向量 |
| $\tilde{\mathbf{P}}_{ij}$ | Br × Bc | 局部未归一化的注意力权重 |
| $\tilde{\ell}_{ij}$ | Br | 局部行和向量 |
| minew | Br | 更新后的全局行最大值 |
| ℓinew | Br | 更新后的全局行和 |
| Oi | Br × d | 输出的第 i 个分块 |
首先,FlashAttention 将输入矩阵 Q, K, V 划分为若干小块。假设片上缓存的大小为 M,则 Q 被划分为 Tr = ⌈N/Br⌉ 个块,每块大小为 Br × d;K 和 V 被划分为 Tc = ⌈N/Bc⌉ 个块,每块大小为 Bc × d。这里 Br 和 Bc 的选择基于缓存的大小和特征维度 d。