cython的使用加速计算

2019-01-14  本文已影响15人  LuDon

Python的使用具有以下优点:

例子:

# dot_python.py
import numpy as np

def naive_dot(a, b):
    if a.shape[1] != b.shape[0]:
        raise ValueError('shape not matched')
    n, p, m = a.shape[0], a.shape[1], b.shape[1]
    c = np.zeros((n, m), dtype=np.float32)
    for i in xrange(n):
        for j in xrange(m):
            s = 0
            for k in xrange(p):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    return c

矩阵是保存在numpy.ndarray中的,计算速度比c/c++慢很多,使用cython实现:

# dot_cython.pyx
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.float32_t, ndim=2] _naive_dot(np.ndarray[np.float32_t, ndim=2] a, np.ndarray[np.float32_t, ndim=2] b):
    cdef np.ndarray[np.float32_t, ndim=2] c
    cdef int n, p, m
    cdef np.float32_t s
    if a.shape[1] != b.shape[0]:
        raise ValueError('shape not matched')
    n, p, m = a.shape[0], a.shape[1], b.shape[1]
    c = np.zeros((n, m), dtype=np.float32)
    for i in xrange(n):
        for j in xrange(m):
            s = 0
            for k in xrange(p):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    return c

def naive_dot(a, b):
    return _naive_dot(a, b)

看起来和Python的差不多,区别在于:

 cdef int my_min(int x, int y):
     return x if x <= y else y

cython程序需要先编译在调用,流程如下

# setup.py
from distutils.core import setup, Extension
from Cython.Build import cythonize
import numpy
setup(ext_modules = cythonize(Extension(
    'dot_cython',   # 动态链接库的名字
    sources=['dot_cython.pyx'],  
    language='c',
    include_dirs=[numpy.get_include()],
    library_dirs=[],
    libraries=[],
    extra_compile_args=[],
    extra_link_args=[]
)))

然后需要执行以下命令就可以把cython程序编译成动态链接库了

python setup.py build_ext --inplace

运行完之后,当前目录就多出了dot_cython.c和dot_cython.so。前者是生成的c程序,后者是编译好了的动态链接库。

上一篇下一篇

猜你喜欢

热点阅读