Strassen矩阵乘法
2019-01-23 本文已影响0人
ZakWind
一、思路
假设n是2的幂。将矩阵A,B和C中每一矩阵都分块成4个大小相等的子矩阵,每个子矩阵都是的方阵。由此可将方程C=AB重写为
定义
则
时间复杂度
二、C++代码:
//C++
#include <iostream>
using namespace std;
//矩阵类
class matrix {
private:
int **mp;//矩阵数组
int n;//矩阵的阶
public:
//创建零矩阵
explicit matrix(int n) {
this->n = n;
mp = new int *[n];
for (int i = 0; i < n; ++i) {
mp[i] = new int[n];
for (int j = 0; j < n; ++j) {
mp[i][j] = 0;
}
}
}
//使用数组创建矩阵
matrix(int n, int **mp) {
this->n = n;
this->mp = new int *[n];
for (int i = 0; i < n; ++i) {
this->mp[i] = new int[n];
for (int j = 0; j < n; ++j) {
this->mp[i][j] = mp[i][j];
}
}
}
//以矩阵A的1/4部分创建矩阵
matrix(matrix A, int p1, int p2) {
n = A.n / 2;
mp = new int *[n];
for (int i = 0; i < n; ++i) {
mp[i] = new int[n];
for (int j = 0; j < n; ++j) {
mp[i][j] = A.mp[i + (p1 - 1) * (n)][j + (p2 - 1) * (n)];
}
}
}
matrix operator+(const matrix &b) {
matrix c(this->n);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
c.mp[i][j] = this->mp[i][j] + b.mp[i][j];
}
}
return c;
}
matrix operator-(const matrix &b) {
matrix c(this->n);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
c.mp[i][j] = this->mp[i][j] - b.mp[i][j];
}
}
return c;
}
void show() {
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
cout << mp[i][j] << " ";
}
cout << endl;
}
}
//四个子矩阵合并成一个矩阵
void merge(matrix a11, matrix a12, matrix a21, matrix a22) {
for (int i = 0; i < n / 2; i++) {
for (int j = 0; j < n / 2; j++) {
mp[i][j] = a11.mp[i][j];
}
for (int j = n / 2; j < n; j++) {
mp[i][j] = a21.mp[i][j - n / 2];
}
}
for (int i = n / 2; i < n; i++) {
for (int j = 0; j < n / 2; j++) {
mp[i][j] = a12.mp[i - n / 2][j];
}
for (int j = n / 2; j < n; j++) {
mp[i][j] = a22.mp[i - n / 2][j - n / 2];
}
}
}
//乘法
static matrix multiply(matrix a, matrix b) {
matrix c(a.n);
if (a.n == 1) {
c.mp[0][0] = a.mp[0][0] * b.mp[0][0];
} else {
matrix a11(a, 1, 1), a12(a, 1, 2), a21(a, 2, 1), a22(a, 2, 2);
matrix b11(b, 1, 1), b12(b, 1, 2), b21(b, 2, 1), b22(b, 2, 2);
matrix m1 = multiply(a11, b12 - b22);
matrix m2 = multiply(a11 + a12, b22);
matrix m3 = multiply(a21 + a22, b11);
matrix m4 = multiply(a22, b21 - b11);
matrix m5 = multiply(a11 + a22, b11 + b22);
matrix m6 = multiply(a12 - a22, b21 + b22);
matrix m7 = multiply(a11 - a21, b11 + b12);
c.merge(m5 + m4 - m2 + m6, m1 + m2, m3 + m4, m5 + m1 - m3 - m7);
}
return c;
}
};
int main() {
int n = 4;
int array[4][4] = {{1, 2, 3, 4},
{5, 6, 7, 8},
{9, 10, 11, 12},
{13, 14, 15, 16}};
//将array转换成动态数组
int **a = new int *[n];
for (int i = 0; i < n; ++i) {
a[i] = new int[n];
for (int j = 0; j < n; ++j) {
a[i][j] = array[i][j];
}
}
//创建矩阵
matrix m(n, a);
//自己乘自己
matrix::multiply(m, m).show();
}
三、运行结果
90 202 314 426
100 228 356 484
110 254 398 542
120 280 440 600