Creating PyTorch Extensions Using NumPy and SciPy

2019-07-16  捡个七

Official tutorial: CREATING EXTENSIONS USING NUMPY AND SCIPY

This tutorial covers two tasks: implementing a parameter-free layer with NumPy, and implementing a layer with learnable weights with SciPy.

Implementing a parameter-free layer with NumPy

The layer below doesn't perform any useful or mathematically correct computation, hence its name: BadFFTFunction.

# layer implementation

import torch
from torch.autograd import Function
from numpy.fft import rfft2, irfft2

class BadFFTFunction(Function):

    @staticmethod
    def forward(ctx, input):
        # detach the input so it can be converted to a NumPy array
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return input.new(result)

    @staticmethod
    def backward(ctx, grad_output):
        # gradients flow back through the (incorrect) inverse FFT
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return grad_output.new(result)

Since this layer has no parameters, we can simply expose it as a function rather than as an nn.Module:

def incorrect_fft(input):
    return BadFFTFunction.apply(input)

Here is an example of using this layer:

inputs = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(inputs)
print(result)
result.backward(torch.randn(result.size()))
print(inputs)

-------------------------------------------------------------------------
tensor([[ 3.5953,  2.3891,  2.8538,  6.3056,  7.1890],
        [ 6.0135, 10.8107,  4.2032,  9.4689, 10.2098],
        [ 4.6084,  4.5200,  7.8461,  5.3306, 16.6947],
        [ 1.1328,  3.6691,  5.6570, 10.1536,  1.2553],
        [ 4.9080,  3.0635,  4.9613,  5.5422, 10.7650],
        [ 1.1328, 10.7622, 11.3006, 12.5434,  1.2553],
        [ 4.6084,  9.3826,  6.1878,  3.6052, 16.6947],
        [ 6.0135,  2.6298,  4.7681,  0.3978, 10.2098]],
       grad_fn=<BadFFTFunctionBackward>)
tensor([[ 1.8835,  0.4974, -1.0209,  0.1234,  0.3349, -2.1377,  0.1967, -1.2438],
        [-0.6187, -1.3692,  1.9919, -0.6665, -0.4790, -1.1658, -1.0086,  0.0427],
        [-0.9035,  0.5733, -1.9797,  0.3805, -0.4385,  1.7815,  0.2453,  0.3710],
        [-0.5477,  0.9553, -0.7232, -0.9086, -0.7948,  0.9149,  0.4236, -0.2123],
        [-1.4582, -0.9862,  0.6265, -0.5989,  0.7842,  0.7988, -0.3591,  0.8035],
        [-0.1081,  0.4932, -0.2232,  0.5371,  0.7379, -0.5363, -0.6724, -0.0632],
        [-1.7535,  2.3054,  0.0435,  1.2096, -0.0145,  0.5476, -0.3470,  0.3916],
        [-0.5269, -0.5503,  0.2355, -0.2890,  0.0305, -0.4156,  1.0513,  0.2139]],
       requires_grad=True)
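
A note on the output shape: for a real 8×8 input, rfft2 keeps only the non-redundant half of the spectrum along the last axis (n // 2 + 1 = 5 columns), which is why the first tensor above is 8×5. A quick self-contained check:

import numpy as np
from numpy.fft import rfft2

# only 5 of the 8 frequency columns are non-redundant for real input
print(rfft2(np.random.randn(8, 8)).shape)  # (8, 5)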

Implementing a parametrized layer with SciPy

In the deep learning literature, this layer is confusingly referred to as a convolution, while the actual operation is cross-correlation (the only difference being that the filter is flipped for convolution, and not for cross-correlation), as the short check below shows.
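
A minimal check of this relationship (the arrays below are purely illustrative):

import numpy as np
from scipy.signal import convolve2d, correlate2d

x = np.random.randn(5, 5)
k = np.array([[1., 2.], [3., 4.]])

# convolution == cross-correlation with the filter flipped on both axes
print(np.allclose(convolve2d(x, k, mode='valid'),
                  correlate2d(x, np.flip(k), mode='valid')))  # True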

The cross-correlation has a filter (kernel) representing the weights, so this is a layer with learnable weights. Its backward pass computes both the gradient with respect to the input and the gradient with respect to the filter.

from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
import torch
from torch.autograd import Function
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

class ScipyConv2dFunction(Function):

    @staticmethod
    def forward(ctx, input, filter, bias):
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input, filter, bias = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_bias = np.sum(grad_output, keepdims=True)
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # the previous line can be expressed equivalently as:
        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
        return torch.from_numpy(grad_input), \
               torch.from_numpy(grad_filter).to(torch.float), \
               torch.from_numpy(grad_bias).to(torch.float)
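
A quick shape check makes the choice of correlation/convolution modes in backward concrete (this sketch assumes the imports above and uses the same sizes as the example below, a 10×10 input and a 3×3 filter):

x = np.random.randn(10, 10)
k = np.random.randn(3, 3)
out = correlate2d(x, k, mode='valid')             # forward output
print(out.shape)                                  # (8, 8)
print(convolve2d(out, k, mode='full').shape)      # (10, 10), like grad_input
print(correlate2d(x, out, mode='valid').shape)    # (3, 3), like grad_filter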

Wrap the operation in a Module class; registering filter and bias as Parameter makes them learnable weights that appear in module.parameters():

class ScipyConv2d(Module):
    def __init__(self, filter_width, filter_height):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(filter_width, filter_height))
        self.bias = Parameter(torch.randn(1, 1))
        
    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter, self.bias)

Here is an example of using this layer:

module = ScipyConv2d(3, 3)
print("Filter and bais: ", list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
output.backward(torch.randn(8, 8))
print("Gradient for the input map: ", input.grad)

-----------------------------------------------------------------------------
Filter and bias:  [Parameter containing:
tensor([[ 1.0172, -0.7830, -1.9644],
        [-1.7501, -0.3380,  1.0851],
        [-0.6086,  0.5211, -0.1384]], requires_grad=True), Parameter containing:
tensor([[0.8491]], requires_grad=True)]
Output from the convolution:  tensor([[-3.3643, -1.6414,  3.8635,  5.7214,  4.2812, -0.1469,  2.2956,  4.6972],
        [ 1.0405,  5.4137, -0.2289,  3.7867, -0.8485,  1.0467,  5.0971,  0.6170],
        [ 0.3865,  7.9669,  4.7172, -5.9195,  2.6202,  4.1359, -1.2188,  4.6258],
        [-4.0765, -1.9985,  3.0376,  3.7519,  4.8408, -0.5378,  0.9233,  2.9950],
        [ 7.2145, -0.1482,  1.9535,  2.1877, -0.5471,  6.3192,  6.6404,  4.5604],
        [ 2.6525,  1.4568,  8.2622,  2.1857, -4.5970, -0.7388, -1.2843,  3.0592],
        [ 3.2907,  4.0466, -2.7943, -2.3269, -0.5543,  7.4176,  2.9281,  0.6315],
        [ 5.6153,  1.4405, -8.2649, -3.6808,  7.4088,  4.8308,  0.6125,  0.2748]],
       grad_fn=<ScipyConv2dFunctionBackward>)
Gradient for the input map:  tensor([[ 8.4448e-01, -4.6131e-01, -1.2356e+00, -2.3001e-01, -2.7439e+00,
         -9.6755e-01,  3.9761e+00,  3.8412e-01, -1.0720e+00,  1.3304e+00],
        [-2.0427e+00,  5.0312e-01, -1.3896e-01, -9.8333e-01,  3.3517e+00,
          1.8381e+00, -2.5191e+00, -1.6409e+00,  5.2481e-01, -4.0503e-01],
        [-3.4304e-03,  9.7143e-01,  8.0939e-01, -2.3209e+00, -2.4818e+00,
         -2.2358e+00,  3.3594e-01,  9.6761e-01, -8.7727e-01,  1.7346e+00],
        [ 1.2670e+00, -3.0389e+00, -1.3391e+00,  1.4903e-01,  1.7144e+00,
         -2.2407e-01,  5.4215e-01,  2.1312e+00, -2.2236e+00, -2.2285e+00],
        [ 6.0892e-01, -1.5455e+00,  3.4901e+00, -3.1687e+00, -3.5638e+00,
          5.3970e+00, -4.1608e+00, -7.5911e-01,  5.0879e+00,  2.5559e+00],
        [ 4.9064e-01,  3.2317e+00, -6.9631e+00, -4.6371e+00,  4.4206e+00,
         -6.6388e-02,  1.6657e+00,  8.6398e-01, -4.3631e+00, -6.9194e-01],
        [-1.7784e+00, -1.9765e+00, -5.0315e+00,  3.8658e+00,  1.1239e+00,
         -3.7742e+00, -2.5467e+00, -1.1219e+00, -3.4360e-01,  1.1228e+00],
        [ 4.4786e-01, -4.6717e+00, -5.5782e-01, -1.5868e-01, -8.8934e+00,
          2.3656e+00,  2.7402e+00,  4.5009e+00,  2.4637e+00, -1.5834e+00],
        [-3.2312e+00, -1.3407e+00,  2.0052e-01, -1.1472e-02,  4.3446e+00,
          3.0356e+00, -1.3052e+00, -7.6964e-01, -1.5648e+00,  6.0754e-01],
        [-1.0473e+00,  8.7615e-01, -1.1456e+00,  1.1731e+00,  5.9753e-01,
         -1.8710e-01,  1.7740e-01, -5.7756e-01,  3.6896e-01, -6.6725e-02]])

Finally, verify the gradients with gradcheck, which compares the analytical gradients computed by backward against numerical gradients obtained by finite differences (hence the double-precision input and the small eps):

from torch.autograd.gradcheck import gradcheck

moduleConv = ScipyConv2d(3, 3)

input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)

--------------------------------------------------------
Are the gradients correct:  True