4.2 分治策略.矩阵乘法的Strassen算法
如果你以前曾经接触过矩阵,可能了解如何进行矩阵乘法(否则,请阅读D.1节)。若A= (aij) 和B=(bij)是n*n的方阵,则对i, j=1, 2, ... , n, 定义乘积C=A*B中的元素cij为:
(4.8)我们需要计算n^2个矩阵元素,每个元素是n个值的和。下面过程接收n*n矩阵A和B,返回它们的乘积——n*n矩阵C。假设每个矩阵都有一个属性rows,给出矩阵的行数。
过程SQUARE-MATRIX-MULTIPLY工作过程如下。第3~7行的for循环计算每行中的元素,第i行中,第4~7行的for循环计算每列中的元素cij。第5行将cij初始化为0,开始公式(4.8)中的求和计算,第6~7行的for循环的每步迭代将公式(4.8)中的每一项累加进来。
由于三重for循环的每一重都恰好执行n步,而第7行每次执行都花费常量时间,因此过程SQUARE-MATRIX-MULTIPLY花费THETA(n^3)时间。
你最初可能认为任何矩阵乘法都要花费OMEGA(n3)时间,因为矩阵乘法的自然定义就需要进行这么多次的标量乘法。但这是错误的:我们有方法在o(n3)时间内完成矩阵乘法。在本节中,我们将看到Strassen的著名n*n矩阵相乘的递归算法。我们将在4.5节证明其运行时间为THETA(nlg7)。由于lg7在2.80和2.81之间,因此,Strassen算法的运行时间为O(n2.81),渐近复杂性优于简单的SQUARE-MATRIX-MULTIPLY过程。
一个简单的分治算法
为简单起见,当使用分治算法计算矩阵积C=A*B时,假定三个矩阵均为n*n矩阵,其中n为2的幂。我们做出这个假设是因为在每个分解步骤中,n*n矩阵都被划分为4个n/2*n/2的子矩阵,如果假定n是2的幂,则只要n>=2即可保证子矩阵规模n/2为整数。
假定将A、B和C均分解为4个n/2*n/2的子矩阵: (4.9)
因此可以将公式C=A*B改写为:
(4.10)公式(4.10)等价于如下4个公式:
每个公式对应两对n/2*n/2矩阵的乘法及n/2*n/2积的加法。我们可以利用这些公式设计一个直接的递归分治算法:
这段伪代码掩盖了一个微妙但重要的实现细节。在第5行应该如何分解矩阵?如果我们真的创建12个新的n/2*n/2矩阵,将会花费THETA(n^2)时间复制矩阵元素。实际上,我们可以不必复制元素就能完成矩阵分解,其中的诀窍是使用下标计算。我们可以通过原矩阵的一组行下标和一组列下标来指明一个子矩阵。最终表示子矩阵的方法与表示原矩阵的方法略有不同,这就是我们省略的细节。这种表示方法的好处是,通过下标计算指明子矩阵,执行第5行只需THETA(1)的时间(虽然我们将看到是否通过复制元素来分解矩阵对总渐近运行时间并无影响)。
现在,我们推导出一个递归式来刻画SQUARE-MATRIX-MULTIPLY-RECURSIVE的运行时间。令T(n)表示用此过程计算两个n*n矩阵乘积的时间。对n=1的基本情况,我们只需进行一次标量乘法(第4行),因此
T(1)=THETA(1)------(4.15)
当n>1时是递归情况。如前文所讨论,在第5行使用下标计算来分解矩阵花费THETA(1)时间。第6~9行,我们共8次递归调用SQUARE-MATRIX-MULTIPLY-RECURSIVE。由于每次递归调用完成两个n/2*n/2矩阵的乘法,因此花费时间为T(n/2),8次递归调用总时间为8T(n/2)。我们还需要计算第6~9行的4次矩阵加法。每个矩阵包含(n2)/4个元素,因此,每次矩阵加法花费THETA(n2)时间。由于矩阵加法的次数是常数,第6~9行进行矩阵加法的总时间为THETA(n^2)(这里我们仍然使用下标计算方法将矩阵加法的结果放置于矩阵C的正确位置,由此带来的额外开销为每个元素THETA(1)时间)。因此,递归情况的总时间为分解时间、递归调用时间及矩阵加法时间之和:
T(n)=THETA(1)+8T(n/2)+THETA(n2)=8T(n/2)+THETA(n2)---(4.16)
注意,如果通过复制元素来实现矩阵分解,额外开销为THETA(n^2),递归式不会发生改变,只是总运行时间将会提高常数倍。
组合公式(4.15)和公式(4.16),我们得到SQUARE-MATRIX-MULTIPLY-RECURSIVE运行时间的递归式: (4.17)
我们在4.5节将会看到利用主方法求解递归式(4.17),得到的解为T(n)=THETA(n^3)。因此,简单的分治算法并不优于直接的SQUARE-MATRIX-MULTIPLY过程。
在继续介绍Strassen算法之前,让我们先回顾一下公式(4.16)的几个组成部分都是从何而来的。用下标计算方法分解每个n*n矩阵花费THETA(1)时间,但有两个矩阵需要分解。虽然你可能认为分解两个矩阵需要THETA(2)时间,但实际上THETA符号中已经包含常数2在内了。假定每个矩阵包含k个元素,则两个矩阵相加需花费THETA(k)时间。由于每个矩阵包含了(n2)/4个元素,每次加法花费THETA((n2)/4)时间。我们需要进行4次矩阵加法,再次,我们并不说花费了THETA(4n2)时间,而是THETA(n2)时间。(当然,你可能发现我们可以说4次矩阵加法花费了THETA(4n2/4)时间,而4n2/4=n2,但此处的要点是THETA符号已经包含了常数因子,无论怎样的常数因子均可省略。)因此,我们最终得到两项THETA(n2),可以将它们合二为一。
但是,当分析8次递归调用时,就不能简单省略常数因子8了。换句话说,我们必须说递归调用共花费8T(n/2)时间,而不是T(n/2)时间。至于这是为什么,你可以回顾一下图2-5中的递归树,它对应递归式(2.1)(与递归式(4.7)相同),其递归情况为T(n)=2T(n/2)+THETA(n)。因子2决定了树中每个结点有几个孩子结点,进而决定了树的每一层为总和贡献了多少项。如果省略公式(4.16)中的因子8或递归式(4.1)中的因子2,递归数就变为线性结构,而不是“茂盛的”了,树中的每一层只为总和贡献了一项。
因此,切记,虽然渐近符号包含了常数因子,但递归符号(如T(n/2))并不包含。
Strassen方法
Strassen方法对的核心思想是令递归树稍微不那么茂盛一点,即只递归进行7次而不是8次n/2*n/2矩阵的乘法。减少一次矩阵乘法带来的代价可能是额外几次n/2*n/2矩阵的加法,但只是常数次。与前文一样,当建立递归式刻画运行时间时,常数次矩阵加法被THETA符号包含在内。
Strassen算法不是那么直观(这可能是本书陈述最不充分的地方了)。它包含4个步骤:
- 按公式(4.9)将输入矩阵A、B和输出矩阵C分解为n/2*n/2的子矩阵。采用下标计算方法,此步骤花费THETA(1)时间,与SQUARE-MATRIX-MULTIPLY-RECURSIVE相同。
- 创建10个n/2*n/2的矩阵S1, S2, ... , S10,每个矩阵保存步骤1中创建的两个子矩阵的和或差。花费时间为THETA(n^2)。
- 用步骤1中创建的子矩阵和步骤2中创建的10个矩阵,递归地计算7个矩阵积P1, P2, ..., P7。每个矩阵Pi都是n/2*n/2的。
- 通过Pi矩阵的不同组合进行加减计算,计算出结果矩阵C的子矩阵C11, C12, C21, C22。花费时间THETA(n^2)。
我们稍后会看到步骤2~4的细节,但现在可以建立Strassen算法的运行时间递归式。假定一旦矩阵规模从n变为1,就进行简单的标量乘法计算,正如SQUARE-MATRIX-MULTIPLY-RECURSIVE的第4行那样。当n>1时,步骤1、2和4共花费THETA(n^2)时间,步骤3要求进行7次n/2*n/2矩阵的乘法。因此,我们得到如下描述Strassen算法运行时间T(n)的递归式:
(4.18)我们用常数次矩阵乘法(注:此处翻译有误,应为加法)的代价减少了一次矩阵乘法。一旦我们理解了递归式及其解,就会看到这种交换确实能带来更低的渐近运行时间。利用4.5节的主方法,可以求出递归式(4.18)的解为T(n)=THETA(n^lg7)。
我们现在来介绍Strassen算法的细节。在步骤2中,创建如下10个矩阵:
由于必须进行10次n/2*n/2矩阵的加减法,因此,该步骤花费THETA(n^2)时间。
在步骤3中,递归地计算7次n/2*n/2矩阵的乘法,如下所示:
注意,上述公式中,只有中间一列的乘法是真正需要计算的。右边这列只是用来说明这些乘积与步骤1创建的原始子矩阵之间的关系。
步骤4对步骤3创建的Pi矩阵进行加减法运算,计算出C的4个n/2*n/2的子矩阵,首先,
C11=P5+P4-P2+P6
利用每个Pi的展开式展开等式右部,每个Pi的展开式位于单独一行,并将可以消去的项垂直对齐,我们可以看到C11等于
与公式(4.11)相同。类似地,令
C12=P1+P2
则C12等于
与公式(4.12)相同。令
C21=P3+P4
使C21等于
与公式(4.13)相同。最后,令
C22=P5+P1-P3-P7
则C22等于
与公式(4.14)相同。在步骤4中,共进行了8次n/2*n/2矩阵的加减法,因此花费THETA(n^2)时间。
因此,我们看到由4个步骤构成的Strassen算法,确实生成了正确的矩阵乘积,递归式(4.18)刻画了它的运行时间。由于我们将在4.5节看到次递归式的解为T(n)=THETA(n^lg7),Strassen方法的渐近复杂性低于直接的SQUARE-MATRIX-MULTIPLY过程。本章注记会讨论Strassen算法实际应用方面的一些问题。
练习
注意:虽然练习4.2-3、4.2-4和4.2-5是关于Strassen算法的变形的,但你应该先阅读4.5节,然后再尝试求解这几个问题。
4.2-1 使用Strassen算法计算如下矩阵乘法:
给出计算过程。
4.2-2 为Strassen算法编写伪代码。
4.2-3 如何修改Strassen算法,使之适应矩阵规模n不是2的幂的情况?证明:算法的运行时间为 THETA(n^lg7)。
4.2-4 如果可以用k次乘法操作(假定乘法的交换律不成立)完成两个3*3矩阵相乘,那么你可以在o(n^lg7)时间内完成n*n矩阵相乘,满足这一条件的最大的k是多少?此算法的运行时间是怎样的?
4.2-5 V. Pan发现一种方法,可以用132 464次操作完成68*68的矩阵相乘,发现另一种方法,可以用143 640次乘法操作完成70*70的矩阵相乘,还发现一种方法,可以用155 424次乘法操作完成72*72的矩阵相乘。当用于矩阵相乘的分治算法时,上述哪种方法会得到最佳的渐近运行时间?与Strassen算法相比,性能如何?
4.2-6 用Strassen算法作为子进程来进行一个kn*n矩阵和一个n*kn矩阵相乘,最快需要花费多长时间?对两个输入矩阵规模互换的情况,回答相同的问题。
k*THETA(nlg7)=THETA(nlg7)
k*THETA(nlg7)+k*THETA(n2)=THETA(n^lg7)
4.2-7 设计算法,仅使用三次实数乘法即可完成复数a+bi和c+di相乘。算法需接收a、b、c和d为输入,分别生成实部ac-bd和虚部ad+bc。
A=(a-b)d
B=(c-d)a
C=(d+c)*b
A+B=ac-bd
A+C=ad+bc