Computing dot products with AVX256

2021-12-22  GOGOYAO
#include <immintrin.h>

#include <iostream>
#include <vector>
// Dot product of two float vectors using 256-bit AVX intrinsics.
// Assumes emb_1 and emb_2 have the same size and that the size is a
// multiple of 8; leftover elements are not handled here.
// (An alternative accumulation scheme is sketched after the full listing.)
bool DotProductSimd256(const std::vector<float>& emb_1,
                       const std::vector<float>& emb_2, double& res) {
  static const size_t kBlockWidth = 8;  // a __m256 holds 8 floats
  res = 0;
  const float* a = emb_1.data();
  const float* b = emb_2.data();
  size_t k = emb_1.size() / kBlockWidth;
  for (size_t i = 0; i < k; i++) {
    __m256 ai = _mm256_loadu_ps(a + i * kBlockWidth);
    __m256 bi = _mm256_loadu_ps(b + i * kBlockWidth);
    // Mask 0xF1: the high nibble (0xF) multiplies and sums all 4 floats of
    // each 128-bit lane, the low nibble (0x1) writes that sum into element 0
    // of the lane, so the two partial sums land in r[0] and r[4].
    __m256 r = _mm256_dp_ps(ai, bi, 0xF1);
    // Indexing a __m256 with r[i] is a GCC/Clang vector extension.
    res += (r[0] + r[4]);
  }
  return true;
}

// Plain scalar reference implementation.
bool DotProduct(const std::vector<float>& emb_1,
                const std::vector<float>& emb_2, double& res) {
  res = 0;
  for (size_t i = 0; i < emb_1.size(); i++) {
    res += emb_1[i] * emb_2[i];
  }
  return true;
}

int main(int argc, char* argv[]) {
  std::vector<float> f_vec_1;
  std::vector<float> f_vec_2;
  const int len = 8;
  for (int i = 0; i < len; i++) {
    f_vec_1.push_back(0.01 * (i + 1));
    f_vec_2.push_back(0.01 * (i + 1));
  }

  double res = 0;
  DotProductSimd256(f_vec_1, f_vec_2, res);
  std::cout << "use avx res = " << res << std::endl;

  DotProduct(f_vec_1, f_vec_2, res);
  std::cout << "not use avx res = " << res << std::endl;
  return 0;
}
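
_mm256_dp_ps performs a horizontal sum inside every loop iteration, and the underlying dpps instruction has a fairly high latency on most CPUs. A common alternative for long vectors is to keep a running 8-float accumulator and do a single horizontal reduction at the end. The sketch below shows that variant, including a scalar tail for lengths that are not a multiple of 8; DotProductSimd256Acc is a name made up for this example, and it reuses the <immintrin.h>/<vector> headers from the listing above.

bool DotProductSimd256Acc(const std::vector<float>& emb_1,
                          const std::vector<float>& emb_2, double& res) {
  const size_t kBlockWidth = 8;
  const float* a = emb_1.data();
  const float* b = emb_2.data();
  size_t k = emb_1.size() / kBlockWidth;
  __m256 acc = _mm256_setzero_ps();  // 8 running partial sums
  for (size_t i = 0; i < k; i++) {
    __m256 ai = _mm256_loadu_ps(a + i * kBlockWidth);
    __m256 bi = _mm256_loadu_ps(b + i * kBlockWidth);
    acc = _mm256_add_ps(acc, _mm256_mul_ps(ai, bi));  // acc += ai * bi
  }
  // Horizontal reduction: fold the upper 128-bit lane onto the lower one,
  // then reduce the remaining 4 floats to a single value.
  __m128 lo = _mm256_castps256_ps128(acc);
  __m128 hi = _mm256_extractf128_ps(acc, 1);
  __m128 sum4 = _mm_add_ps(lo, hi);
  __m128 sum2 = _mm_add_ps(sum4, _mm_movehl_ps(sum4, sum4));
  __m128 sum1 = _mm_add_ss(sum2, _mm_shuffle_ps(sum2, sum2, 0x1));
  res = _mm_cvtss_f32(sum1);
  // Scalar tail for lengths that are not a multiple of 8.
  for (size_t i = k * kBlockWidth; i < emb_1.size(); i++) {
    res += a[i] * b[i];
  }
  return true;
}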

Compile with: gcc -mavx2 -std=c++11 -g main.cpp -lstdc++ (or simply use g++ instead of gcc).
Check whether the machine supports AVX2: lscpu | grep avx2
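
If you also want a check at runtime inside the program itself, GCC and Clang expose the __builtin_cpu_supports builtin; a minimal sketch:

#include <iostream>

int main() {
  // Returns non-zero if the CPU the program is running on supports AVX2.
  if (__builtin_cpu_supports("avx2")) {
    std::cout << "avx2 supported" << std::endl;
  } else {
    std::cout << "avx2 not supported" << std::endl;
  }
  return 0;
}
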
In this test the SIMD version is more than twice as fast as the scalar version.
P.S.: if the -O3 optimization flag is enabled, the compiler's own optimizations (such as auto-vectorizing the scalar loop) may leave the hand-written SIMD version with no advantage at all.
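
The measured speedup depends on the vector length, the repetition count, and the compiler flags. A rough timing harness along the following lines can be used to reproduce the comparison; the length and round count are arbitrary illustrative values, and this main replaces the small one in the listing above.

#include <chrono>
#include <iostream>
#include <vector>

// Assumes DotProductSimd256 and DotProduct from the listing above are visible.
int main() {
  const size_t len = 1 << 16;  // illustrative length, a multiple of 8
  const int rounds = 10000;    // illustrative repetition count
  std::vector<float> v1(len), v2(len);
  for (size_t i = 0; i < len; i++) {
    v1[i] = 0.001f * i;
    v2[i] = 0.002f * i;
  }

  double res = 0;
  auto t0 = std::chrono::steady_clock::now();
  for (int r = 0; r < rounds; r++) DotProductSimd256(v1, v2, res);
  auto t1 = std::chrono::steady_clock::now();
  for (int r = 0; r < rounds; r++) DotProduct(v1, v2, res);
  auto t2 = std::chrono::steady_clock::now();

  std::cout << "res = " << res << std::endl;  // keep the result observable
  std::cout << "simd:   "
            << std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count()
            << " us" << std::endl;
  std::cout << "scalar: "
            << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count()
            << " us" << std::endl;
  return 0;
}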
