在上一篇中,我们介绍了 Flash Attention v1 的基本原理和分块计算的思想。本文将深入探讨 Flash Attention v1 的实现细节。
的是 Flash Attention 的 CUDA 内核接收三个输入张量 Q(Q)、K(键)、V(值),计算 QKT,然后经过 softmax 归一化得到注意力权重,最终与 V 相乘得到输出 O。在这个过程中,还维护了中间状态 l 和 m(分别代表累积指数和归一化系数),以便分块累积计算长序列时保持数值稳定性。总体上,每个线程块(block)负责处理一个 batch 的一个 head 内的一部分数据,利用共享内存(sram)来减少全局内存访问延迟。整个内核通过嵌套两层循环来实现对大矩阵分块计算的过程。
核函数在启动前,主机代码首先要确定每个线程块(block)需要使用的共享内存大小。这里的计算公式为:
const int sram_size = (3 * Bc * D * sizeof(float)) + (Bc * Br * sizeof(float));
这段代码由两部分组成:
(3 * Bc * D * sizeof(float)) 此处 3 代表共享内存中划分出来的三个区域:Qi、Kj 和 Vj。Bc 表示每个 block 中需要加载的元素个数(其实和线程数有关,每个线程负责加载 1 组 d 元素),D 就是每个向量的维度,也就是每个线程加载的数据条数。 sizeof(float) 是每个浮点数的字节数(通常为 4 字节)。 整体来看,这部分计算出的是存储 Qi、Kj 和 Vj 这三个数据块所需要的共享内存总字节数。
(Bc * Br * sizeof(float)) 此部分对应共享内存中 S 区域,用于存储中间计算结果。 Br 则通常代表内层循环中维度的大小,同样乘上 Bc 与 float 的字节数,得到对应的共享内存所需大小。
将两个部分相加,就得到了每个 block 所需的共享内存总量 sram_size。如此计算可以确保在调用内核时把共享内存传递进去,从而保证内核中的动态共享内存可以正确使用。
为了避免请求的共享内存超过设备的最大允许值,程序调用了:
int max_sram_size;cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);printf("Max shared memory: %d, requested shared memory: %d \\n", max_sram_size, sram_size);
如果请求的 sram_size 超过 max_sram_size,那么内核启动时将会失败,这时候我们需要调整 Bc、D、Br 的参数,找到平衡点,既能保证算法所需内存,又不会超过硬件限制。
这里为了简单起见,在代码中直接将 Bc 和 Br 写成了固定值。值得注意的是,这个 Br 和 Bc 的值是可以不一样的,并且一定有Br ≤ Bc。
const int Bc = 32;const int Br = 16;
至于为什么一定会有Br ≤ Bc,则可以回到在 Flash Attention V1 的论文里,其计算方式为 $Bc=\lceil \frac{M}{4d} \rceil$,$Br= min(\lceil \frac{M}{4d} \rceil, d)$。其中M是设备每个 SM 所能使用的最大共享内存空间大小,d是每个向量的维度。4d表示的是 Q,K,V,S 使用共享内存的子块大小之和。这里会发现当$\lceil \frac{M}{4d} \rceil > d$时,$Br = d < \lceil \frac{M}{4d} \rceil = Bc$。当$\lceil \frac{M}{4d} \rceil \leq d$时,$Br = \lceil \frac{M}{4d} \rceil = Bc$。所以一定会有Br ≤ Bc。
根据这个性质,当 Br 与 Bc 不相等时时,也可以只用简单的 if 语句就可以完成 Q 子块的加载,但设置 Bc 和 Br 的时候最好是相等的,可以提高 GPU 线程的利用率。
接下来,我们需要设置 CUDA 内核的执行维度(也就是 gridDim 和 blockDim):
dim3 grid_dim(B, nh); // B: batch 大小,nh: head 数dim3 block_dim(Bc); // 每个块内有 Bc 个线程
网格(grid)的维度设置为 dim3(B, nh),第一个维度 B 对应批量(batch)大小,每个 batch 分配一个 block 行。第二个维度 nh 则对应多头注意力中的 head 数,每个 head 分配到不同的 block 列。这样就保证了在一个 kernel 启动中,不同 batch 与不同 head 可以并行执行而互不影响。