在上文中,我们介绍了手写实现矩阵乘法的方法。本文我们将介绍如何优化矩阵乘法的性能。矩阵乘法有很多优化算法,本文中我们着重介绍两种算法:共享内存缓存块和一维 Thread Tile 并行优化。
在全局内存之外,GPU 还有一块位于芯片上的较小区域,被称为共享内存(SMEM)。每个 SM(流多处理器)都配备了一块共享内存。
以下是 A100 GPU 内存层次结构的图示:

picture 0
从逻辑上看,共享内存在各个块之间进行了分区。这意味着一个线程可以通过共享内存块与同一块内的其他线程进行通信。共享内存的大小是可配置的,可以通过权衡以获得更大的共享内存而减小 L1 缓存的大小。
对于这个新的内核,我们将 A 和 B 的全局内存一块加载到共享内存中。接着,我们将在这两块上尽可能多地执行计算。这样做的好处是,我们可以减少对全局内存的访问次数,因为共享内存的访问速度比全局内存快得多。
计算的流程如下图所示,可以看到我们将 A 和 B 的一块加载到共享内存中,然后在共享内存中进行计算。

picture 1
我们还是延续了上一篇文章中的矩阵乘法的实现,只是在内核中加入了共享内存的使用。每一个线程负责计算 C 中的一个元素。
以下是代码的重要部分,其中变量名称对应上面的图表:
// 推进指针到起始位置A += cRow * BLOCKSIZE * K; // 行=cRow,列=0B += cCol * BLOCKSIZE; // 行=0,列=cColC += cRow * BLOCKSIZE * N + cCol * BLOCKSIZE; // 行=cRow,列=cColfloat tmp = 0.0;// 外部循环推进 A 沿列和 B 沿行,直到我们完全计算出 C 中的结果。for (int bkIdx = 0; bkIdx < K; bkIdx += BLOCKSIZE) { // 每个线程从全局内存加载 A 和 B 中的一个元素到共享内存中。 // 将 threadCol(=threadIdx.x)设为连续的索引,以允许全局内存访问协同。 As[threadRow * BLOCKSIZE + threadCol] = A[threadRow * K + threadCol]; Bs[threadRow * BLOCKSIZE + threadCol] = B[threadRow * N + threadCol]; // 阻塞本块内的线程,直到缓存完全填充 __syncthreads(); // 在当前缓存块上执行点积 for (int dotIdx = 0; dotIdx < BLOCKSIZE; ++dotIdx) { tmp += As[threadRow * BLOCKSIZE + dotIdx] * Bs[dotIdx * BLOCKSIZE + threadCol]; } // 在最后需要再次同步,以避免更快的线程在较慢的线程完成之前将下一个块提取到缓存中 __syncthreads(); // 推进指针到下一个块 A += BLOCKSIZE; B += BLOCKSIZE * N;}C[threadRow * N + threadCol] = tmp;
对于初学者来说,代码中对于矩阵的索引可能有些难以理解,我们可以结合图来理解。代码中首先将 A、B、C 的指针推进到当前块的起始位置。也就是图中对应的 &A、&B、&C。以 A 举例,A += cRow * BLOCKSIZE * K,其中 cRow 为当前块的行索引,BLOCKSIZE 为块的大小,K 为矩阵 A 的列数。这样就将 A 的指针推进到当前块的起始位置。
接着,我们将 A、B 的数据读取到共享内存中。这里我们使用了二维数组,As[threadRow * BLOCKSIZE + threadCol],其中 threadRow 和 threadCol 分别为当前线程在块中的行索引和列索引。这样就将 A 的数据读取到了共享内存中。同理,我们将 B 的数据也读取到了共享内存中。需要注意的是,这里我们使用了 __syncthreads() 来同步线程,这是由于不同线程之间是并行执行的。如果不同步,可能会导致某些线程读取到的数据不是最新的。这也是比较难理解的地方,虽然我们在编写代码时,是按照顺序编写的,但是实际上不同线程是并行执行的,写代码的时候需要考虑到这一点。
后面就是计算矩阵乘法的过程,可以对比一下上一篇文章中的矩阵乘法的实现。这里我们外层循环的 Step 是 BLOCKSIZE,也就是每次计算 32 个数据。内层循环的 Step 是 1,也就是每次计算一个数据。这里我们使用了 tmp 变量来保存计算的结果,最后再将 tmp 写入到 C 中。这里我们使用了 C[threadRow * N + threadCol],其中 threadRow 和 threadCol 分别为当前线程在块中的行索引和列索引。这样就将 C 的数据写入到了共享内存中。