在初级系列中我们已经实现了一个简单的矩阵乘法的 kernel,并使用共享内存和一维线程块来优化了矩阵乘法的性能。在 GEMM 优化专栏里面,我们将会继续优化矩阵乘法的性能,这一节我们将会使用二维线程块来优化矩阵乘法的性能。
在介绍二维 Thread Tile 之前,我们先来回顾一下一维 Thread Tile 的优化方法。在初级系列中,我们使用了一维线程块来优化矩阵乘法的性能,我们将矩阵乘法的计算任务分配给了一维线程块,每个线程块负责计算一个小的矩阵块。这样做的好处是可以充分利用共享内存,减少全局内存的访问次数,从而提高矩阵乘法的性能。
还记得一维 Thread Tile 中的例子吗?如果输入的 A 和 B 都是 7x7 的矩阵:
上述的 2 就是一维 Thread Tile 优化,上述的 3 就是 二维 Thread Tile 优化,计算结果不变的同时,减少 IO 次数,提升算法的执行时间。所以想要继续优化这个 Kernel 的性能,我们可以使用二维线程块来计算二维的矩阵块。
本文的主要优化思路就是让每个线程计算一个 8*8 的网格。下面我们来看一下这个 Kernel 的主题思路图:

picture 1
首先在内核的第一阶段, 所有线程协同工作, 从全局内存中加载矩阵 A 和矩阵 B 到共享内存中。
当 SMEM 缓存填充完毕后,每个线程负责将其相关的 SMEM 条目相乘,并将结果累加到本地寄存器中。可以看到, 每个线程计算的是一个 TM * TN 的矩阵块。如果图中的 TN 是 1, 那么就是一维 Thread Tile。
接下来让我们动手实现这个内核, 我们按照上面的原理图来写代码。
写 Kernel 代码的时候, 适当画图是非常有帮助的。这样可以帮助我们更好的理解 Kernel 的执行流程。
首先我们需要定义一些常量方便后续使用:
// Block 索引const uint c_row = blockIdx.y;const uint c_col = blockIdx.x;// Thread 索引const uint thread_col = threadIdx.x % (BN / TN);const uint thread_row = threadIdx.x / (BN / TN);// 二维 tile (block tile) 的大小const uint total_results_block_tile = BM * BN;// 一个 block tile 需要的线程数量const uint number_threads_block_tile = total_results_block_tile / (TM * TN);assert(number_threads_block_tile == blockDim.x);// 计算该 Thread 负责加载的共享内存索引const uint inner_row_A = threadIdx.x / BK;const uint inner_col_A = threadIdx.x % BK;// 计算每个线程块一次加载的行数const uint stride_A = number_threads_block_tile / BK;const uint inner_row_B = threadIdx.x / BN;const uint inner_col_B = threadIdx.x % BN;const uint stride_B = number_threads_block_tile / BN;