5行代码,在C语言中调用CUDA加速进行张量计算
迄今为止最懒但是最好用的方法
============================================================================================================================================================================================================================
我们先给出一个需求:
intmain(intargc,char*argv[]){intc_arr_0[] = {1,2,3,4,5,6};intc_arr_1[] = {7,8,9,10,11,12};intc_arr_2[] = {0,0,0,0,0,0};//计算c_arr_0与c_arr_1的元素乘积,代码开始//开始你的表演//代码结束,越少越好for(inti=0;i
计算过程中的需求是
要适应各种尺寸的输入数据、要支持各种各样的计算类型(加减乘除,各种向量、矩阵计算、各种张量计算还有神经网络前向后项……),上面的代码就是给了个简单的例子
GPU加速
代码量要少,超过10行就头疼
我的结果
#include"py.h"intmain(intargc,char*argv[]){intc_arr_0[] = {1,2,3,4,5,6};intc_arr_1[] = {7,8,9,10,11,12};intc_arr_2[] = {0,0,0,0,0,0};//convert c array to py list intc_shape[] = {6}; py shape = py_from_int_list(c_shape);//convert c array to torch tensor with shapepy array0 = py_from_array(c_arr_0,sizeof(c_arr_0),shape,PY_INT); py array1 = py_from_array(c_arr_1,sizeof(c_arr_1),shape,PY_INT);//call torch functions. need specify number of argumentspy array2 = py_call("torch.mul",2,array0,array1);//convert back and displaypy_to_array(array2,c_arr_2,sizeof(c_arr_2));for(inti=0;i
上面的代码非常明快,而且符合人类的基本认知:我们的目标是完成数学计算,没必要在这个过程中学习CUDA、OpenCL等一大堆并行设备编程的知识。也不用学习C++、STL,libtorch也没必要学了。
特别指出的是,尽管这段代码的背后都是Python,但是在API中完全掩盖了Python的痕迹。仔细观察发现它实际上调用了PyTorch,PyTorch的功能非常丰富(也有CUDA加速),只需要修改py_call的参数就能呼叫PyTorch中的任意功能。
关键技术
实际上这些C代码的背后都是Python和PyTorch的驱动。利用Cython所提供的C和Python混合编程的方法提供一个友好人性化的API。其中涉及到几个关键技术:
1. C数组与Python类型的转换
这个操作比较简单,Cython提供了非常好的支持
2. C数组与Numpy Array/Torch Tensor类型的转换
我们这里采用一个投机取巧的方法,即首先把C的数组强制转换成char *,即原始的内存空间,然后利用Cython的存储转换功能得到bytes类型的Python值,然后使用numpy array的frombuffer方法将bytes转换为array。得到numpy array之后,就可以比较容易的得到Torch Tensor以及复制到GPU当中了。这个过程没有内存复制。
反向转换也比较容易,首先先转换为Numpy的Array,然后用tobytes的方法得到存储值。
3. 调用PyTorch的函数
从Python API中返回的都是PyObject * 的类型,我们在API里掩人耳目,把PyObject * 进行了typedef,即代码里的 py 类型。、
调用PyTorch里的函数,我们用一种曲折的方法成功实现了Cython导出函数的变长参数,大家可以看看代码里是怎么实现的。用这种方法,可以调用Python库中的任何函数。
思考题
在之前的文章中提到,如果想在C中调用Python,必须调用Py_Initialize()进行初始化,但是今天我们的main函数里却没有这个调用。这个是如何实现的呢?(提示:不能用修改程序入口点的方法)
如果有想学习c++的程序员,可来我们的C/C++学习扣qun:589348389,
微信公众号:java大世界(现在是发布c++学习干货哦,没有java)
免费送C++的视频教程噢!
我每晚上8点还会在群内直播讲解C/C++知识,欢迎大家前来学习哦。