在初级系列中我们已经实现了一个简单的矩阵乘法的 kernel,并使用共享内存和一维线程块来优化了矩阵乘法的性能。在 GEMM 优化专栏里面,我们将会继续优化矩阵乘法的性能,这一节我们将会使用二维线程块来优化矩阵乘法的性能。

1. 一维 Thread Tile

在介绍二维 Thread Tile 之前,我们先来回顾一下一维 Thread Tile 的优化方法。在初级系列中,我们使用了一维线程块来优化矩阵乘法的性能,我们将矩阵乘法的计算任务分配给了一维线程块,每个线程块负责计算一个小的矩阵块。这样做的好处是可以充分利用共享内存,减少全局内存的访问次数,从而提高矩阵乘法的性能。

还记得一维 Thread Tile 中的例子吗?如果输入的 A 和 B 都是 7x7 的矩阵:

  1. 如果我们一次读取 1 行 A 和 1 列 B,当每一个线程只计算一个结果的时候,我们需要从 A 中读取 7 个数据,从 B 中读取 7 个数据,从 C 中读取 1 个数据,然后写 1 次 C。这样的话,每个线程需要读取 15 个数据,写 1 次数据。计算 16 个结果需要 16 个线程,共 16x16 = 256 次 IO。
  2. 如果我们一次读取 4 行 A 和 1 列 B,那么每一个线程计算 4 个结果,此时需要从 A 中读取 4x7 个数据,从 B 中读取 7 个数据,从 C 中读取 4 个数据,然后写 4 次 C。计算 16 个结果需要 4 个线程,共 4x43 = 172 次 IO。
  3. 如果我们一次读取 4 行 A 和 4 列 B,那么每一个线程计算 16 个结果,此时需要从 A 中读取 4x7 个数据,从 B 中读取 4x7 个数据,从 C 中读取 16 个数据,然后写 16 次 C。计算 16 个结果一共需要 1 个线程,共 1x88 = 88 次 IO。

上述的 2 就是一维 Thread Tile 优化,上述的 3 就是 二维 Thread Tile 优化,计算结果不变的同时,减少 IO 次数,提升算法的执行时间。所以想要继续优化这个 Kernel 的性能,我们可以使用二维线程块来计算二维的矩阵块。

2. 二维 Thread Tile

2.1 优化思路

本文的主要优化思路就是让每个线程计算一个 8*8 的网格。下面我们来看一下这个 Kernel 的主题思路图:

picture 1

picture 1

首先在内核的第一阶段, 所有线程协同工作, 从全局内存中加载矩阵 A 和矩阵 B 到共享内存中。

当 SMEM 缓存填充完毕后,每个线程负责将其相关的 SMEM 条目相乘,并将结果累加到本地寄存器中。可以看到, 每个线程计算的是一个 TM * TN 的矩阵块。如果图中的 TN 是 1, 那么就是一维 Thread Tile。

2.2 代码实现

接下来让我们动手实现这个内核, 我们按照上面的原理图来写代码。

写 Kernel 代码的时候, 适当画图是非常有帮助的。这样可以帮助我们更好的理解 Kernel 的执行流程。

首先我们需要定义一些常量方便后续使用:

// Block 索引const uint c_row = blockIdx.y;const uint c_col = blockIdx.x;// Thread 索引const uint thread_col = threadIdx.x % (BN / TN);const uint thread_row = threadIdx.x / (BN / TN);// 二维 tile (block tile) 的大小const uint total_results_block_tile = BM * BN;// 一个 block tile 需要的线程数量const uint number_threads_block_tile = total_results_block_tile / (TM * TN);assert(number_threads_block_tile == blockDim.x);// 计算该 Thread 负责加载的共享内存索引const uint inner_row_A = threadIdx.x / BK;const uint inner_col_A = threadIdx.x % BK;// 计算每个线程块一次加载的行数const uint stride_A = number_threads_block_tile / BK;const uint inner_row_B = threadIdx.x / BN;const uint inner_col_B = threadIdx.x % BN;const uint stride_B = number_threads_block_tile / BN;