Optimizing Matrix Multiplication with Shared Memory
2023-11-30
leon_tly
CUDA matrix multiplication: basic version
Optimizing matrix multiplication with shared memory
Tiled (block) matrix multiplication
where A, B, C, D, E, F, G and H are the sub-matrix tiles (the block product is sketched below)
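Since the original figure is not reproduced here, the following is a minimal sketch of the block product, assuming the two input matrices were partitioned as [[A, B], [C, D]] and [[E, F], [G, H]]:

$$
\begin{pmatrix} A & B \\ C & D \end{pmatrix}
\begin{pmatrix} E & F \\ G & H \end{pmatrix}
=
\begin{pmatrix} AE + BG & AF + BH \\ CE + DG & CF + DH \end{pmatrix}
$$

Each block of the result is a sum of small matrix products, and this is exactly what the shared-memory kernels below accumulate tile by tile.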
Tiled matrix multiplication with shared memory (special case)
- Assumptions
  - The tile size is 2 x 2, and the thread block size is also 2 x 2
  - The matrices are square, of size N x N
- Algorithm
  - Load each tile into shared memory
  - Multiply the two tiles (each thread computes exactly one element)
  - Accumulate the tile products
  - Write the accumulated value to the corresponding element of the result matrix
(Animations: cuda-tiled-matrix_mul.gif, cuda-tiled-matrix_mul-1.gif, illustrating the tiled matrix multiplication steps)
- CUDA code
__global__ void tile_matrix_multiply_shared(float* A, float* B, float* C, int width)
{
    // One 2 x 2 tile of A and one of B are staged in shared memory per block
    __shared__ float shareA[2][2];
    __shared__ float shareB[2][2];
    int bx = blockIdx.x;
    int by = blockIdx.y;
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int col = bx * 2 + tx;
    int row = by * 2 + ty;
    float temp = 0;
    // Walk over the tiles along the shared dimension
    for (int i = 0; i < width / 2; ++i)
    {
        // Each thread loads one element of the current A tile and B tile
        shareA[ty][tx] = A[row * width + (i * 2 + tx)];
        shareB[ty][tx] = B[(i * 2 + ty) * width + col];
        __syncthreads();
        // Partial dot product contributed by the current tile
        for (int k = 0; k < 2; ++k)
        {
            temp += shareA[ty][k] * shareB[k][tx];
        }
        // Wait until every thread is done with the tile before it is overwritten
        __syncthreads();
    }
    C[row * width + col] = temp;
}
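For completeness, here is a minimal host-side launch sketch, not part of the original post. It assumes N is a multiple of 2, and the wrapper name launch_tile_matmul_2x2 and the host buffers h_A, h_B, h_C are hypothetical.

#include <cuda_runtime.h>

// Hypothetical wrapper: copies the inputs to the GPU, launches the kernel,
// and copies the result back. Assumes N is a multiple of 2.
void launch_tile_matmul_2x2(float* h_A, float* h_B, float* h_C, int N)
{
    size_t bytes = (size_t)N * N * sizeof(float);
    float *d_A, *d_B, *d_C;
    cudaMalloc(&d_A, bytes);
    cudaMalloc(&d_B, bytes);
    cudaMalloc(&d_C, bytes);
    cudaMemcpy(d_A, h_A, bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B, bytes, cudaMemcpyHostToDevice);
    dim3 block(2, 2);           // one thread per element of a 2 x 2 tile
    dim3 grid(N / 2, N / 2);    // one block per 2 x 2 tile of the result
    tile_matrix_multiply_shared<<<grid, block>>>(d_A, d_B, d_C, N);
    cudaMemcpy(h_C, d_C, bytes, cudaMemcpyDeviceToHost);
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
}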
Tiled matrix multiplication with shared memory (generalized case)
#define BLOCK_SIZE 16
// P : m x n
// Q : n x k
// R : m x k
template<typename T>
__global__ void tile_matrix_multiply_shared(T* P, T* Q, T* R, int m, int n, int k)
{
    // Tiles of P and Q staged in shared memory
    __shared__ T ds_p[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ T ds_q[BLOCK_SIZE][BLOCK_SIZE];
    int ix = blockDim.x * blockIdx.x + threadIdx.x;   // column index in R
    int iy = blockDim.y * blockIdx.y + threadIdx.y;   // row index in R
    T r_value = 0;
    // Walk over the tiles along the shared dimension n,
    // copying each tile from global memory to shared memory
    for (int i = 0; i < (n + BLOCK_SIZE - 1) / BLOCK_SIZE; i++)
    {
        if (iy < m && i * BLOCK_SIZE + threadIdx.x < n)
        {
            ds_p[threadIdx.y][threadIdx.x] = P[iy * n + i * BLOCK_SIZE + threadIdx.x];
        }
        else
        {
            ds_p[threadIdx.y][threadIdx.x] = 0;
        }
        if (ix < k && i * BLOCK_SIZE + threadIdx.y < n)
        {
            ds_q[threadIdx.y][threadIdx.x] = Q[(i * BLOCK_SIZE + threadIdx.y) * k + ix];
        }
        else
        {
            ds_q[threadIdx.y][threadIdx.x] = 0;
        }
        __syncthreads();
        // Accumulate the partial dot product contributed by this tile
        for (int c = 0; c < BLOCK_SIZE; ++c)
        {
            r_value += ds_p[threadIdx.y][c] * ds_q[c][threadIdx.x];
        }
        __syncthreads();
    }
    // Copy the final value to the result matrix R
    if (iy < m && ix < k)
    {
        R[iy * k + ix] = r_value;
    }
}
Note: as far as my current knowledge goes, the threads launched in CUDA map one-to-one to the elements of the output matrix (threads that fall outside its bounds are simply ignored).
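To illustrate the note above, here is a hedged launch sketch for the generalized kernel; the wrapper name launch_tile_matmul and the device buffers d_P, d_Q, d_R are assumptions, not from the original post.

// Hypothetical launch wrapper: the grid is sized so that every element of the
// m x k output matrix R gets exactly one thread; out-of-range threads are
// masked off inside the kernel.
void launch_tile_matmul(float* d_P, float* d_Q, float* d_R, int m, int n, int k)
{
    dim3 block(BLOCK_SIZE, BLOCK_SIZE);
    dim3 grid((k + BLOCK_SIZE - 1) / BLOCK_SIZE,   // enough blocks to cover the k columns of R
              (m + BLOCK_SIZE - 1) / BLOCK_SIZE);  // enough blocks to cover the m rows of R
    tile_matrix_multiply_shared<float><<<grid, block>>>(d_P, d_Q, d_R, m, n, k);
    cudaDeviceSynchronize();
}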