Posts GEMM算子优化
Post
Cancel

GEMM算子优化

GEMM算子优化

GEMM优化的核心在于:

  1. 提高Cache命中率,设计更好的数据排布(Tiling)
  2. 提高并行度,充分利用指令向量化和多核并行

先看一个native的实现:

1
2
3
4
5
6
7
8
9
10
void naive_row_major_sgemm(const float* A, const float* B, float* C, const int M,
                           const int N, const int K) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            for (int k = 0; k < K; ++k) {
                C[m * N + n] += A[m * K + k] * B[k * N + n];
            }
        }
    }
}

其浮点运算量为2xMxNxK,缓存读次数为2xMxNxK,写次数为MxN;但是这种写法会造成大量的Cache miss: img

简单的循环重排就可以得到加速:

1
2
3
4
5
6
7
8
9
10
void optimize_row_major_sgemm(const float* A, const float* B, float* C, const int M,
                              const int N, const int K) {
    for (int m = 0; m < M; ++m) {
        for (int k = 0; k < K; ++k) {
            for (int n = 0; n < N; ++n) {
                C[m * N + n] += A[m * K + k] * B[k * N + n];
            }
        }
    }
}

进一步提高性能,需要做矩阵A/B的分块计算,以CUDA而言:需要设计从Global Memory->Shared Memory->Register的分块逻辑,核心是提高Cache命中率。

img

引入共享内存缓存子矩阵块,可以减少对全局内存的访问:

  • 每个线程块负责计算 C 的一个子矩阵块。
  • 每次从全局内存加载一个子块到共享内存。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#define TILE_M 16
#define TILE_N 16
#define TILE_K 16

__global__ void gemm_tiled(float *A, float *B, float *C, int M, int N, int K) {
    // 定义共享内存
    __shared__ float tile_A[TILE_M][TILE_K];
    __shared__ float tile_B[TILE_K][TILE_N];

    // 计算当前线程块的位置
    int block_row = blockIdx.y;
    int block_col = blockIdx.x;

    // 计算当前线程的位置
    int thread_row = threadIdx.y;
    int thread_col = threadIdx.x;

    // 结果存储到寄存器
    float value = 0.0f;

    // 遍历所有子块
    for (int k = 0; k < (K + TILE_K - 1) / TILE_K; ++k) {
        // 加载 A 的子块到共享内存
        if (block_row * TILE_M + thread_row < M && k * TILE_K + thread_col < K) {
            tile_A[thread_row][thread_col] = A[(block_row * TILE_M + thread_row) * K + (k * TILE_K + thread_col)];
        } else {
            tile_A[thread_row][thread_col] = 0.0f;
        }

        // 加载 B 的子块到共享内存
        if (k * TILE_K + thread_row < K && block_col * TILE_N + thread_col < N) {
            tile_B[thread_row][thread_col] = B[(k * TILE_K + thread_row) * N + (block_col * TILE_N + thread_col)];
        } else {
            tile_B[thread_row][thread_col] = 0.0f;
        }

        __syncthreads();

        // 从共享内存加载到寄存器并计算
        for (int t = 0; t < TILE_K; ++t) {
            value += tile_A[thread_row][t] * tile_B[t][thread_col];
        }

        __syncthreads();
    }

    // 写回全局内存
    if (block_row * TILE_M + thread_row < M && block_col * TILE_N + thread_col < N) {
        C[(block_row * TILE_M + thread_row) * N + (block_col * TILE_N + thread_col)] = value;
    }
}

另外,为了减少循环控制的开销可以手动展开循环提升性能:假设矩阵维度是循环展开大小的倍数,将循环迭代分为固定大小的展开步长。

1
2
3
4
5
6
7
8
9
10
11
#define UNROLL_FACTOR 4  // 展开因子
// 原始实现:从共享内存加载到寄存器并计算
//        for (int t = 0; t < TILE_K; ++t) {
//            value += tile_A[thread_row][t] * tile_B[t][thread_col];
//        }
for (int t = 0; t < TILE_K; t += UNROLL_FACTOR) {
    value += tile_A[thread_row][t] * tile_B[t][thread_col];
    value += tile_A[thread_row][t + 1] * tile_B[t + 1][thread_col];
    value += tile_A[thread_row][t + 2] * tile_B[t + 2][thread_col];
    value += tile_A[thread_row][t + 3] * tile_B[t + 3][thread_col];
}
This post is licensed under CC BY 4.0 by the author.