LeetCode:字符串相乘

2020-11-21  本文已影响0人  阿臻同学

字符串相乘 - LeetCode

导入依赖

主要依赖的库有:

import math
import random

拆分、填充

N=6 时,有 i=4,2^4=16

满足 2^{3} \lt 2 \times 6 \le 2^{4}

def to_list(num1:str, num2:str) -> tuple:
    # 拆分为 list
    a = [int(i) for i in num1]
    b = [int(i) for i in num2]
    
    # 反转列表,将低阶项系数放在列表前面
    a.reverse()
    b.reverse()
    max_len = max(len(a),len(b))
    
    # 对齐使长度相等
    l = len(a)-len(b)
    zeros = [0] * abs(l) 
    if l < 0:
        a = a + zeros
    elif l > 0:
        b = b + zeros
        
    # 补充前导 0,使得长度为 2^n
    fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
    fill = [0] * fill_count
    return a+fill,b+fill

多项式表示

对于输入的两个数 AB ,将其处理成两个多项式:

A(x) = a_0 + a_1x^1 + a_2x^2 + \cdots + a_{N-1}x^{N-1} = \sum_{j=0}^{N-1}a_jx^j

B(x) = b_0 + b_1x^1 + b_2x^2 + \cdots + b_{N-1}x^{N-1} = \sum_{j=0}^{N-1}b_jx^j

最终的目标是对多项式 C(x) = A(x) \times B(x) 进行求解。

C(x) = c_0 + c_1x^1 + c_2x^2 + \cdots + c_{2N-1}x^{2N-1} = \sum_{j=0}^{2N-1}c_jx^j

傅里叶变换求解

def multiply(num1: str, num2: str) -> str:
    l = len(num1)
    a,b = to_list(num1, num2)
    # 傅里叶变换
    a_fft, b_fft = fft(a), fft(b)
    t = []
    # 对应项相乘
    for i in range(len(a_fft)):
        t.append(a_fft[i] * b_fft[i]) 
    # 逆傅里叶变换
    ans = idft(t)
    sum = 0
    # 计算多项式
    for i,r in enumerate(ans):
        # 实部四舍五入取整
        sum += int(r.real+0.5) * (10 ** i)
    return str(sum)

傅里叶变换实现

傅里叶变换与逆傅里叶变换的主要区别在于:逆傅里叶变换需要对计算的结果除以 N (并不是在递归中进行),并且在计算的过程中 \omega = \omega^{-1}

def _ft(l:list, idft = False):
    """
    基础的变换方法,通过变量控制进行dft还是idft
    
    :param bool idft: 控制进行傅里叶变换还是逆傅里叶变换
    """
    
    n = len(l)
    if n == 1:
        return l
    
    # dft 与 idft 分别处理 $\omega$
    o_n_e = -2j if idft else 2j
    o = 1
    o_n = math.e ** (o_n_e * math.pi / n)
    
    # 拆分奇偶项
    even_index = l[::2]
    odd_index = l[1::2]
    
    y_even = _ft(even_index, idft)
    y_odd = _ft(odd_index, idft)

    y = [0]*n
    for i in range(n//2):
        y[i] = y_even[i] + o * y_odd[i]
        y[i+n//2] = y_even[i] - o * y_odd[i]
        o *= o_n
    return y

def fft(l:list):
    """
    傅里叶变换
    """
    output = _ft(l)
    return output

def idft(l:list):
    """
    逆傅里叶变换
    """
    n = len(l)
    output = _ft(l,True)
    # 将计算的结果除以 $N$
    output = [i/n for i in output]
    return output

测试

multiply()方法输出的结果与自带的乘法计算结果进行比较,并输出测试结果。

def test(num1:str, num2:str):
    r = int(multiply(num1,num2))
    s = int(num1)*int(num2)
    t = 30
    print(f"{'-'*t} Test {'-'*t}")
    print(f"Test case: \n\t{num1} \n\t{num2}")
    print(f"Program output: \n\t{r}")
    print(f"Expected output: \n\t{s}")
    print(f"❌ FAILED" if r != s else "✔ OK")
    return r == s

编写测试用例

# 测试用例数
test_cases = 10
# 数据长度
INT_MAX = 1e100
for i in range(test_caese):
    num1 = str(random.randint(0, INT_MAX))
    num2 = str(random.randint(0, INT_MAX))
    test(num1, num2)

LeetCode AC 代码

class Solution:
    def to_list(self, num1:str, num2:str) -> tuple:
        a = [int(i) for i in num1]
        b = [int(i) for i in num2]
        a.reverse()
        b.reverse()
        l = len(a)-len(b)
        max_len = max(len(a),len(b))
        # 对齐使长度相等
        zeros = [0] * abs(l) 
        if l < 0:
            a = a + zeros
        elif l > 0:
            b = b + zeros
        # 补充前导 0,使得长度为 2^n
        fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
        fill = [0] * fill_count
        return a+fill,b+fill
    
    def multiply(self, num1: str, num2: str) -> str:
        l = len(num1)
        a,b = self.to_list(num1, num2)
        a_fft, b_fft = self.fft(a), self.fft(b)

        
        t = []
        _3 = []
        for i in range(len(a_fft)):
            t.append(a_fft[i] * b_fft[i]) 
        ans = self.idft(t)
        
        sum = 0
        for i,r in enumerate(ans):
            sum += int(r.real+0.5) * (10 ** i)
        return str(sum)
        
        

    def _ft(self, l:list, idft = False):
        n = len(l)
        if n == 1:
            return l
        o_n_e = -2j if idft else 2j
        even_index = l[::2]
        odd_index = l[1::2]
        o = 1
        o_n = math.e ** (o_n_e * math.pi / n)
        
        y_even = self._ft(even_index, idft)
        y_odd = self._ft(odd_index, idft)
        
        y = [0]*n
        for i in range(n//2):
            y[i] = y_even[i] + o * y_odd[i]
            y[i+n//2] = y_even[i] - o * y_odd[i]
            o *= o_n
        return y
    
    def fft(self, l:list):
        output = self._ft(l)
        return output
    
    def idft(self, l:list):
        n = len(l)
        output = self._ft(l,True)
        output = [i/n for i in output]
        return output
上一篇下一篇

猜你喜欢

热点阅读