GammaPerfMNN实践

python的多维向量操作没那么神 —— 异Shape多维向量点

2019-09-29  本文已影响0人  Chriszzzz

1 神奇的Python多维向量操作

用惯 python 的同学都知道,python 对于多维向量操作的灵活性可谓是 有!如!神!迹!
来来,眼见为实!

python多维向量的操作示意

即便 2个多维向量输入的形状不同,它们依然可以进行计算并得到一个我们 相对期望的结果
相关的 python 代码如下,大家也可以自己再体验下。

import numpy as np

aa = np.full((3,4,8), 1);

bb = np.full((3,4,8), 2);
cc = aa + bb;
print(cc)

bb = np.full((1,1,1), 2);
cc = aa + bb;
print(cc)

bb = np.full((1,4,8), 2);
cc = aa + bb;
print(cc)

bb = np.full((3,1,1), 2);
cc = aa + bb;
print(cc)

然而,这样的灵活很难实现么?其实都是说穿不如一文钱的啦~走着!

2 思路整理

代码要写,但不要着急,咱们先把需求搞明白,思路想清楚。

要实现 异Shape多维向量点加,我们要引入一个 维度步幅(Stride) 的概念,太难解释了。先记住它是要点就好了,跟着我继续实践中去理解。

我们先从 一维的 异Shape操作 看起。

2.1 一维点加

一维点加

如图,一个 长度为5的数组(数组A) 和一个 长度为1的数组(数组B) 相加,我们期望 数组A 的每个值和 数组B 的唯一一个值分别相加。

for (int i = 0; i < 5; i++) {
        arrayC[i * strideC] = arrayA[i * strideA] + arrayB[i * strideB];
}

如上面的伪代码,一个循环就可以解决
大家着重关注下图中的 strideB = 0,然后体会一下~

2.2 二维点加

二维点加

如图,当维度扩展为二维时,我们需要 为每个维度 提供一个数组来管理 stride
此时我们的需求 可以用两层循环实现,伪代码如下:

    for (int i = 0; i < 5; i++) {
        pic = arrayC + i * strideC[1];
        pia = arrayA + i * strideA[1];
        pib = arrayB + i * strideB[1];
        
        for (int j = 0; j < 3; j++) {
            pjc = pic + j * strideC[0];
            pja = pia + j * strideA[0];
            pjb = pib + j * strideB[0];
            
            *pjc = *pja + *pjb;
        }
    }

Tips: 一定要逐行代码结合图示分析一下,这种多维的操作太难不实践直接讲清楚了!

那么我们再换一个场景(如下图),大家可以自己套用上面伪代码的逻辑自己推导下~

二维练习图

2.3 三维点加

三维点加

大家用数学归纳自己分析一下,我懒懒地画出图示,列出来stride信息,就不具体分析咯~
伪代码如下,一个三层循环:

for (int i = 0; i < 5; i++) {
        pic = arrayC + i * strideC[2];
        pia = arrayA + i * strideA[2];
        pib = arrayB + i * strideB[2];
        
        for (int j = 0; j < 16; j++) {
            pjc = pic + j * strideC[1];
            pja = pia + j * strideA[1];
            pjb = pib + j * strideB[1];
            
            for (int k = 0; k < 8; k++) {
                pkc = pjc + k * strideC[0];
                pka = pja + k * strideA[0];
                pkb = pjb + k * strideB[0];
                
                *pkc = *pka + *pkb;
            }
        }
    }

外层 i循环 定位面,中层 j循环 定位行,内层 k循环 定位列。

2.4 多维

再一维一维地讲就是浪费篇幅咯一样,感兴趣的话用数学归纳自己推导下并不复杂~

3 参考 MNN BinaryOp 的实现

MNNBinaryOp 操作中,即用到了这种算法思路,不过:

  1. 它固定了 最高支持计算维度为6(很OK了,我们一般也没有超过6维计算的需求)
  2. 它的 外层循环是低维度的变化,内层循环是高纬度的变化(这和我上面介绍的思路 正好相反,所以如果要分析源码时候要注意不要因为算法对不上儿让自己晕掉~)

关键源码函数如下:

// MNN binaryOp 关键计算操作源码

template <typename Tin, typename Tout, typename Func>
static ErrorCode _binaryOp(Tensor* input0, Tensor* input1, Tensor* output) {
    Func f;

    const int input0DataCount = input0->elementSize();
    const int input1DataCount = input1->elementSize();

    const Tin* input0Data = input0->host<Tin>();
    const Tin* input1Data = input1->host<Tin>();
    Tout* outputData      = output->host<Tout>();

    if (input0DataCount == 1) { // data count == 1, not only mean scalar input, maybe of shape (1, 1, 1, ...,1)
        for (int i = 0; i < input1DataCount; i++) {
            outputData[i] = static_cast<Tout>(f(input0Data[0], input1Data[i]));
        }
    } else if (input1DataCount == 1) {
        for (int i = 0; i < input0DataCount; i++) {
            outputData[i] = static_cast<Tout>(f(input0Data[i], input1Data[0]));
        }
    } else { // both input contains more than one element,which means no scalar input
        bool sameShape = input0->elementSize() == input1->elementSize();
        if (sameShape) { // two inputs have the same shape, apply element-wise operation
            for (int i = 0; i < input0DataCount; i++) {
                outputData[i] = static_cast<Tout>(f(input0Data[i], input1Data[i]));
            }
        } else { // not the same shape, use broadcast
#define MAX_DIM 6
            MNN_ASSERT(output->dimensions() <= MAX_DIM);
            int dims[MAX_DIM];
            int stride[MAX_DIM];
            int iStride0[MAX_DIM];
            int iStride1[MAX_DIM];
            for (int i = MAX_DIM - 1; i >= 0; --i) {
                dims[i]     = 1;
                stride[i]   = 0;
                iStride0[i] = 0;
                iStride1[i] = 0;
                int input0I = i - (output->dimensions() - input0->dimensions());
                int input1I = i - (output->dimensions() - input1->dimensions());
                if (i < output->dimensions()) {
                    dims[i]   = output->length(i);
                    stride[i] = output->stride(i);
                }
                if (input0I >= 0 && input0->length(input0I) != 1) {
                    iStride0[i] = input0->stride(input0I);
                }
                if (input1I >= 0 && input1->length(input1I) != 1) {
                    iStride1[i] = input1->stride(input1I);
                }
            }
            for (int w = 0; w < dims[5]; ++w) {
                auto ow  = outputData + w * stride[5];
                auto i0w = input0Data + w * iStride0[5];
                auto i1w = input1Data + w * iStride1[5];
#define PTR(x, y, i)                      \
    auto o##x  = o##y + x * stride[i];    \
    auto i0##x = i0##y + x * iStride0[i]; \
    auto i1##x = i1##y + x * iStride1[I]

                for (int v = 0; v < dims[4]; ++v) {
                    PTR(v, w, 4);
                    for (int u = 0; u < dims[3]; ++u) {
                        PTR(u, v, 3);
                        for (int z = 0; z < dims[2]; ++z) {
                            PTR(z, u, 2);
                            for (int y = 0; y < dims[1]; ++y) {
                                PTR(y, z, 1);
                                for (int x = 0; x < dims[0]; ++x) {
                                    PTR(x, y, 0);
                                    *ox = static_cast<Tout>(f(*i0x, *i1x));
                                }
                            }
                        }
                    }
                }
            }
#undef MAX_DIM
#undef PTR
        }
        // broadcast-capable check is done in compute size
    }

    return NO_ERROR;
}

4 可调式的Demo

MNN 的源码确实不太容易阅读,毕竟它在实现算法的同时:

  1. 考虑的不同操作的兼容(不仅仅支持加法)
  2. 考虑了不同数据类型的兼容
  3. 基于 MNN 的数据结构

但是不用担心,像往常一样,我的技术文章一般都会为大家配套一份简单的参考代码,这份代码相比 MNN 源码:

  1. 只支持加法
  2. 只支持 float 操作类型
  3. 数据结构即 float *std::vector<int> 的组合
  4. 不限维度的操作(即你可以进行 20维的 float 相加操作)
  5. 左加数(A)Shape 可以和 输出(C)Shape 不同(详见 【5 思考 】
// cymv_add 点加主函数
// __add 为支持不限维度操作而实现的递归子函数

static void __add(int dimTag,
                  float *pC, std::vector<int> &rev_stepCs,
                  float *pA, std::vector<int> &rev_stepAs, std::vector<int> &dimAs,
                  float *pB, std::vector<int> &rev_stepBs, std::vector<int> &dimBs) {
    
    int dimNum = (int)dimAs.size();
    
    int curDimA = dimAs[dimTag];
    int curDimB = dimBs[dimTag];
    int curDimC = curDimA > curDimB ? curDimA : curDimB;
    
    int curStepA = rev_stepAs[dimNum - 1 - dimTag];
    int curStepB = rev_stepBs[dimNum - 1 - dimTag];
    int curStepC = rev_stepCs[dimNum - 1 - dimTag];
    
    float *tmppa = pA;
    float *tmppb = pB;
    float *tmppc = pC;
    for (int i = 0; i < curDimC; i++) {
        
        if (dimTag == dimNum - 1) {
            *tmppc = *tmppa + *tmppb;
        } else {
            __add(dimTag + 1,
                  tmppc, rev_stepCs,
                  tmppa, rev_stepAs, dimAs,
                  tmppb, rev_stepBs, dimBs);
        }
        
        tmppc += curStepC;
        tmppa += curStepA;
        tmppb += curStepB;
    }
}

void cymv_add(float *dst,
              float *src0,
              std::vector<int> shape0,
              float *src1,
              std::vector<int> shape1) {
    
    if (shape0.size() != shape1.size()) {
        printf("维度不等,无法计算");
        return;
    }
    
    /* 小维度在前,高维度在后 */
    std::vector<int> step0;
    std::vector<int> step1;
    std::vector<int> stepOut;
    
    int tmpStep0 = 1;
    int tmpStep1 = 1;
    int tmpStepOut = 1;
    for (int i = (int)(shape0.size()) - 1; i >= 0 ; i--) {
        
        if (1 == shape0[i]) {
            step0.push_back(0);
        } else {
            step0.push_back(tmpStep0);
        }
        if (1 == shape1[i]) {
            step1.push_back(0);
        } else {
            step1.push_back(tmpStep1);
        }
        stepOut.push_back(tmpStepOut);
        
        tmpStep0 *= shape0[i];
        tmpStep1 *= shape1[i];
        int maxVal = shape0[i] > shape1[i] ? shape0[i] : shape1[i];
        tmpStepOut *= maxVal;
    }
    
    __add(0,
          dst, stepOut,
          src0, step0, shape0,
          src1, step1, shape1);
}

GitHub可调式源码链接

当然,我说好读也是相对的,一样要费点心思哦!但好在有源码能调试嘛!

5 思考

最后,这样的场景我们的算法能支持么?

思考

下载 Demo 试试看咯!

上一篇下一篇

猜你喜欢

热点阅读