使用 TensorFlow C++ API 实现 MatMul 矩阵乘法操作
2022-02-28 本文已影响0人
FredricZhu
其中 include/tf_/tensor_testutil.h 和 include/tf_/impl/tensor_testutil.cc 是 tensorflow-cc 库自带的文件,但编译该库时它们不知为何没有被编译进去,导致链接失败,所以我把它们单独提取出来,编译成一个自定义库。
此外,我还在其中添加了一个函数 PrintTensorValue,用来打印 Tensor 的值;其余代码均为原生实现。
CMakeLists.txt文件
# Builds the shared helper library (tensor_testutil + death_handler sources)
# and one test executable per *.cpp file in this directory.
#
# CMAKE_CXX_STANDARD 17 is only recognised by CMake >= 3.8, and CMake 4.x no
# longer accepts cmake_minimum_required() versions below 3.5.
cmake_minimum_required(VERSION 3.8)
project(test_tf_session)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Always emit debug information.
add_compile_options(-g)

# conanbuildinfo.cmake is generated by Conan's legacy "cmake" generator and
# provides conan_basic_setup() and CONAN_LIBS.
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
include_directories(${INCLUDE_DIRS})

find_package(TensorflowCC REQUIRED)

file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
file(GLOB APP_SOURCES
     ${INCLUDE_DIRS}/tf_/impl/tensor_testutil.cc
     ${INCLUDE_DIRS}/death_handler/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC)

foreach(test_file ${test_file_list})
  # Executable name = source file name with only the trailing ".cpp" removed.
  get_filename_component(test_file_name ${test_file} NAME)
  string(REGEX REPLACE "\\.cpp$" "" test_name ${test_file_name})
  add_executable(${test_name} ${test_file})
  target_link_libraries(${test_name} PUBLIC ${PROJECT_NAME}_lib)
endforeach()
include/tf_/tensor_testutil.h
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace test {
// Constructs a scalar tensor with 'val'.
template <typename T>
Tensor AsScalar(const T& val) {
  Tensor ret(DataTypeToEnum<T>::value, {});
  ret.scalar<T>()() = val;
  return ret;
}
// Constructs a flat (rank-1) tensor with 'vals'.
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals) {
  Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
  std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
  return ret;
}
// Prints every element of 'tensor', read as element type T, to 'os', one value
// per line, and returns 'os'.
//
// T must match tensor.dtype() (unaligned_flat<T>() presumably CHECK-fails on a
// mismatch). Values are printed with long double's digits10 + 1 significant
// digits; the stream's previous precision is restored afterwards so the
// caller's stream (e.g. std::cout) is not left in a modified state.
template <typename T>
std::ostream& PrintTensorValue(std::ostream& os, Tensor const& tensor) {
  T const* tensor_pt = tensor.unaligned_flat<T>().data();
  auto size = tensor.NumElements();
  const std::streamsize old_precision =
      os.precision(std::numeric_limits<long double>::digits10 + 1);
  for (decltype(size) i = 0; i < size; ++i) {
    os << tensor_pt[i] << "\n";
  }
  os.precision(old_precision);
  return os;
}
// Builds a reduction test graph in scope 's':
//   "input"     - Placeholder of type 'tf_type' with shape 'shape',
//   "axis"      - DT_INT32 Placeholder holding the reduction axes,
//   "my_reduce" - OpType(input, axis) with its keep_dims attr set to 'keep_dims'.
// Returns {input, axis, op} in that order.
template <typename OpType>
std::vector<Output> CreateReduceOP(Scope const& s, DataType tf_type, PartialTensorShape const& shape, bool keep_dims) {
  std::vector<Output> outputs{};
  auto input = ops::Placeholder(s.WithOpName("input"), tf_type, ops::Placeholder::Shape(shape));
  auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
  typename OpType::Attrs op_attrs;
  op_attrs.keep_dims_ = keep_dims;
  auto op = OpType(s.WithOpName("my_reduce"), input, axis, op_attrs);
  outputs.emplace_back(std::move(input));
  outputs.emplace_back(std::move(axis));
  outputs.emplace_back(std::move(op));
  return outputs;
}
// Constructs a tensor of "shape" with values "vals".
// CHECK-fails if the number of elements in 'shape' differs from vals.size().
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
  Tensor ret;
  CHECK(ret.CopyFrom(AsTensor(vals), shape));
  return ret;
}
// Fills in '*tensor' with 'vals'. E.g.,
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillValues<float>(&x, {11, 21, 21, 22});
template <typename T>
void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
  auto flat = tensor->flat<T>();
  CHECK_EQ(flat.size(), vals.size());
  if (flat.size() > 0) {
    std::copy_n(vals.data(), vals.size(), flat.data());
  }
}
// Fills in '*tensor' with 'vals', converting the types as needed.
template <typename T, typename SrcType>
void FillValues(Tensor* tensor, std::initializer_list<SrcType> vals) {
  auto flat = tensor->flat<T>();
  CHECK_EQ(flat.size(), vals.size());
  if (flat.size() > 0) {
    size_t i = 0;
    for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) {
      flat(i) = T(*itr);
    }
  }
}
// Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillIota<float>(&x, 1.0);
template <typename T>
void FillIota(Tensor* tensor, const T& val) {
  auto flat = tensor->flat<T>();
  std::iota(flat.data(), flat.data() + flat.size(), val);
}
// Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
//   test::FillFn<float>(&x, [](int i)->float { return i*i; });
template <typename T>
void FillFn(Tensor* tensor, std::function<T(int)> fn) {
  auto flat = tensor->flat<T>();
  for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
}
// Expects "x" and "y" are tensors of the same type, same shape, and identical
// values (within 4 ULPs for floating point types unless explicitly disabled).
enum class Tolerance {
  kNone,
  kDefault,
};
void ExpectEqual(const Tensor& x, const Tensor& y,
                 Tolerance t = Tolerance::kDefault);
// Expects "x" and "y" are tensors of the same (floating point) type,
// same shape and element-wise difference between x and y is no more
// than atol + rtol * abs(x). If atol or rtol is negative, the data type's
// epsilon * kSlackFactor is used.
void ExpectClose(const Tensor& x, const Tensor& y, double atol = -1.0,
                 double rtol = -1.0);
// Expects "x" and "y" are tensors of the same type T, same shape, and
// equal values. Consider using ExpectEqual above instead.
template <typename T>
void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
  EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
  ExpectEqual(x, y);
}
// Expects "x" and "y" are tensors of the same type T, same shape, and
// approximate equal values. Consider using ExpectClose above instead.
template <typename T>
void ExpectTensorNear(const Tensor& x, const Tensor& y, double atol) {
  EXPECT_EQ(x.dtype(), DataTypeToEnum<T>::value);
  ExpectClose(x, y, atol, /*rtol=*/0.0);
}
// For tensor_testutil_test only.
namespace internal_test {
::testing::AssertionResult IsClose(Eigen::half x, Eigen::half y,
                                   double atol = -1.0, double rtol = -1.0);
::testing::AssertionResult IsClose(float x, float y, double atol = -1.0,
                                   double rtol = -1.0);
::testing::AssertionResult IsClose(double x, double y, double atol = -1.0,
                                   double rtol = -1.0);
}  // namespace internal_test
}  // namespace test
}  // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_
include/tf_/impl/tensor_testutil.cc
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_/tensor_testutil.h"
#include <cmath>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace test {

// Succeeds iff 'x' and 'y' have the same dtype.
static ::testing::AssertionResult IsSameType(const Tensor& x, const Tensor& y) {
  if (x.dtype() != y.dtype()) {
    return ::testing::AssertionFailure()
           << "Tensors have different dtypes (" << x.dtype() << " vs "
           << y.dtype() << ")";
  }
  return ::testing::AssertionSuccess();
}

// Succeeds iff 'x' and 'y' have the same shape (Tensor::IsSameSize).
static ::testing::AssertionResult IsSameShape(const Tensor& x,
                                              const Tensor& y) {
  if (!x.IsSameSize(y)) {
    return ::testing::AssertionFailure()
           << "Tensors have different shapes (" << x.shape().DebugString()
           << " vs " << y.shape().DebugString() << ")";
  }
  return ::testing::AssertionSuccess();
}

// Builds a failure result "<x> not equal to <y>", printing both values with
// numeric_limits<T>::digits10 + 2 significant digits.
template <typename T>
static ::testing::AssertionResult EqualFailure(const T& x, const T& y) {
  return ::testing::AssertionFailure()
         << std::setprecision(std::numeric_limits<T>::digits10 + 2) << x
         << " not equal to " << y;
}

// float equality: two NaNs are equal; with Tolerance::kNone values must
// compare ==, otherwise gtest's ULP-based floating-point comparison is used.
static ::testing::AssertionResult IsEqual(float x, float y, Tolerance t) {
  // We consider NaNs equal for testing.
  if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
    return ::testing::AssertionSuccess();
  if (t == Tolerance::kNone) {
    if (x == y) return ::testing::AssertionSuccess();
  } else {
    if (::testing::internal::CmpHelperFloatingPointEQ<float>("", "", x, y))
      return ::testing::AssertionSuccess();
  }
  return EqualFailure(x, y);
}

// double equality; same rules as the float overload.
static ::testing::AssertionResult IsEqual(double x, double y, Tolerance t) {
  // We consider NaNs equal for testing.
  if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
    return ::testing::AssertionSuccess();
  if (t == Tolerance::kNone) {
    if (x == y) return ::testing::AssertionSuccess();
  } else {
    if (::testing::internal::CmpHelperFloatingPointEQ<double>("", "", x, y))
      return ::testing::AssertionSuccess();
  }
  return EqualFailure(x, y);
}

// Eigen::half equality: two NaNs are equal, a single NaN never is. Otherwise
// the 16-bit patterns are mapped to a biased ordering and compared exactly
// (kNone) or within 4 ULPs (kDefault).
static ::testing::AssertionResult IsEqual(Eigen::half x, Eigen::half y,
                                          Tolerance t) {
  // We consider NaNs equal for testing.
  if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
    return ::testing::AssertionSuccess();
  // Below is a reimplementation of CmpHelperFloatingPointEQ<Eigen::half>, which
  // we cannot use because Eigen::half is not default-constructible.
  if (Eigen::numext::isnan(x) || Eigen::numext::isnan(y))
    return EqualFailure(x, y);
  // Maps a sign-and-magnitude bit pattern onto an unsigned ordering in which
  // adjacent representable values differ by 1. The result must stay in
  // uint16_t: computing `~sam + 1` / `kSignBitMask | sam` in promoted int made
  // +0 and -0 (and small values of opposite sign) look ~65536 ULPs apart.
  auto sign_and_magnitude_to_biased = [](uint16_t sam) -> uint16_t {
    const uint16_t kSignBitMask = 0x8000;
    if (kSignBitMask & sam) {
      return static_cast<uint16_t>(~sam + 1);  // negative number.
    }
    return static_cast<uint16_t>(kSignBitMask | sam);  // positive number.
  };
  auto xb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(x));
  auto yb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(y));
  if (t == Tolerance::kNone) {
    if (xb == yb) return ::testing::AssertionSuccess();
  } else {
    auto distance = xb >= yb ? xb - yb : yb - xb;
    const uint16_t kMaxUlps = 4;
    if (distance <= kMaxUlps) return ::testing::AssertionSuccess();
  }
  return EqualFailure(x, y);
}

// Generic equality via gtest's CmpHelperEQ (operator==); 't' is ignored.
template <typename T>
static ::testing::AssertionResult IsEqual(const T& x, const T& y, Tolerance t) {
  if (::testing::internal::CmpHelperEQ<T>("", "", x, y))
    return ::testing::AssertionSuccess();
  return EqualFailure(x, y);
}

// Complex equality: real and imaginary parts compared separately.
template <typename T>
static ::testing::AssertionResult IsEqual(const std::complex<T>& x,
                                          const std::complex<T>& y,
                                          Tolerance t) {
  if (IsEqual(x.real(), y.real(), t) && IsEqual(x.imag(), y.imag(), t))
    return ::testing::AssertionSuccess();
  return EqualFailure(x, y);
}

// Element-wise equality of two tensors already known to have element type T
// and the same shape. Gives up after 10 mismatching elements.
template <typename T>
static void ExpectEqual(const Tensor& x, const Tensor& y,
                        Tolerance t = Tolerance::kDefault) {
  const T* Tx = x.unaligned_flat<T>().data();
  const T* Ty = y.unaligned_flat<T>().data();
  auto size = x.NumElements();
  int max_failures = 10;
  int num_failures = 0;
  for (decltype(size) i = 0; i < size; ++i) {
    // The streamed message is only evaluated on failure, so this counts
    // mismatches.
    EXPECT_TRUE(IsEqual(Tx[i], Ty[i], t)) << "i = " << (++num_failures, i);
    ASSERT_LT(num_failures, max_failures) << "Too many mismatches, giving up.";
  }
}

// Succeeds when both are NaN, when x == y (covers infinities), or when
// |x - y| <= atol + rtol * |x|.
template <typename T>
static ::testing::AssertionResult IsClose(const T& x, const T& y, const T& atol,
                                          const T& rtol) {
  // We consider NaNs equal for testing.
  if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
    return ::testing::AssertionSuccess();
  if (x == y) return ::testing::AssertionSuccess();  // Handle infinity.
  auto tolerance = atol + rtol * Eigen::numext::abs(x);
  if (Eigen::numext::abs(x - y) <= tolerance)
    return ::testing::AssertionSuccess();
  return ::testing::AssertionFailure() << x << " not close to " << y;
}

// Complex closeness: real and imaginary parts checked separately.
template <typename T>
static ::testing::AssertionResult IsClose(const std::complex<T>& x,
                                          const std::complex<T>& y,
                                          const T& atol, const T& rtol) {
  if (IsClose(x.real(), y.real(), atol, rtol) &&
      IsClose(x.imag(), y.imag(), atol, rtol))
    return ::testing::AssertionSuccess();
  return ::testing::AssertionFailure() << x << " not close to " << y;
}

// Converts a user tolerance to T's real type; a negative value selects the
// default 5 * epsilon(T).
// Return type can be different from T, e.g. float for T=std::complex<float>.
template <typename T>
static auto GetTolerance(double tolerance) {
  using Real = typename Eigen::NumTraits<T>::Real;
  auto default_tol = static_cast<Real>(5.0) * Eigen::NumTraits<T>::epsilon();
  auto result = tolerance < 0.0 ? default_tol : static_cast<Real>(tolerance);
  EXPECT_GE(result, static_cast<Real>(0));
  return result;
}

// Element-wise closeness of two tensors already known to have element type T
// and the same shape. Gives up after 10 mismatching elements.
template <typename T>
static void ExpectClose(const Tensor& x, const Tensor& y, double atol,
                        double rtol) {
  auto typed_atol = GetTolerance<T>(atol);
  auto typed_rtol = GetTolerance<T>(rtol);
  const T* Tx = x.unaligned_flat<T>().data();
  const T* Ty = y.unaligned_flat<T>().data();
  auto size = x.NumElements();
  int max_failures = 10;
  int num_failures = 0;
  for (decltype(size) i = 0; i < size; ++i) {
    EXPECT_TRUE(IsClose(Tx[i], Ty[i], typed_atol, typed_rtol))
        << "i = " << (++num_failures, i) << " Tx[i] = " << Tx[i]
        << " Ty[i] = " << Ty[i];
    ASSERT_LT(num_failures, max_failures)
        << "Too many mismatches (atol = " << atol << " rtol = " << rtol
        << "), giving up.";
  }
  EXPECT_EQ(num_failures, 0)
      << "Mismatches detected (atol = " << atol << " rtol = " << rtol << ").";
}

// Public entry point: checks dtype and shape, then dispatches on dtype.
// Only floating-point and complex types honour 't'.
void ExpectEqual(const Tensor& x, const Tensor& y, Tolerance t) {
  ASSERT_TRUE(IsSameType(x, y));
  ASSERT_TRUE(IsSameShape(x, y));
  switch (x.dtype()) {
    case DT_FLOAT:
      return ExpectEqual<float>(x, y, t);
    case DT_DOUBLE:
      return ExpectEqual<double>(x, y, t);
    case DT_INT32:
      return ExpectEqual<int32>(x, y);
    case DT_UINT32:
      return ExpectEqual<uint32>(x, y);
    case DT_UINT16:
      return ExpectEqual<uint16>(x, y);
    case DT_UINT8:
      return ExpectEqual<uint8>(x, y);
    case DT_INT16:
      return ExpectEqual<int16>(x, y);
    case DT_INT8:
      return ExpectEqual<int8>(x, y);
    case DT_STRING:
      return ExpectEqual<tstring>(x, y);
    case DT_COMPLEX64:
      return ExpectEqual<complex64>(x, y, t);
    case DT_COMPLEX128:
      return ExpectEqual<complex128>(x, y, t);
    case DT_INT64:
      return ExpectEqual<int64>(x, y);
    case DT_UINT64:
      return ExpectEqual<uint64>(x, y);
    case DT_BOOL:
      return ExpectEqual<bool>(x, y);
    case DT_QINT8:
      return ExpectEqual<qint8>(x, y);
    case DT_QUINT8:
      return ExpectEqual<quint8>(x, y);
    case DT_QINT16:
      return ExpectEqual<qint16>(x, y);
    case DT_QUINT16:
      return ExpectEqual<quint16>(x, y);
    case DT_QINT32:
      return ExpectEqual<qint32>(x, y);
    case DT_BFLOAT16:
      return ExpectEqual<bfloat16>(x, y, t);
    case DT_HALF:
      return ExpectEqual<Eigen::half>(x, y, t);
    default:
      EXPECT_TRUE(false) << "Unsupported type : " << DataTypeString(x.dtype());
  }
}

// Public entry point: checks dtype and shape, then dispatches on dtype.
// Only floating-point and complex dtypes are supported.
void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
  ASSERT_TRUE(IsSameType(x, y));
  ASSERT_TRUE(IsSameShape(x, y));
  switch (x.dtype()) {
    case DT_HALF:
      return ExpectClose<Eigen::half>(x, y, atol, rtol);
    case DT_BFLOAT16:
      return ExpectClose<Eigen::bfloat16>(x, y, atol, rtol);
    case DT_FLOAT:
      return ExpectClose<float>(x, y, atol, rtol);
    case DT_DOUBLE:
      return ExpectClose<double>(x, y, atol, rtol);
    case DT_COMPLEX64:
      return ExpectClose<complex64>(x, y, atol, rtol);
    case DT_COMPLEX128:
      return ExpectClose<complex128>(x, y, atol, rtol);
    default:
      EXPECT_TRUE(false) << "Unsupported type : " << DataTypeString(x.dtype());
  }
}

::testing::AssertionResult internal_test::IsClose(Eigen::half x, Eigen::half y,
                                                  double atol, double rtol) {
  return test::IsClose(x, y, GetTolerance<Eigen::half>(atol),
                       GetTolerance<Eigen::half>(rtol));
}

::testing::AssertionResult internal_test::IsClose(float x, float y, double atol,
                                                  double rtol) {
  return test::IsClose(x, y, GetTolerance<float>(atol),
                       GetTolerance<float>(rtol));
}

::testing::AssertionResult internal_test::IsClose(double x, double y,
                                                  double atol, double rtol) {
  return test::IsClose(x, y, GetTolerance<double>(atol),
                       GetTolerance<double>(rtol));
}

}  // end namespace test
}  // end namespace tensorflow
test_tf_session/tf_matmul_test.cpp
#include <string>
#include <vector>
#include <glog/logging.h>
#include "death_handler/death_handler.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
using namespace tensorflow;
// Test entry point: configures glog, constructs Debug::DeathHandler (presumably
// installs crash/stack-trace handlers for the lifetime of main), initialises
// gtest and runs every registered test.
int main(int argc, char** argv) {
  // Write log files into the current directory and mirror them to stderr.
  FLAGS_log_dir = "./";
  FLAGS_alsologtostderr = true;
  // Minimum severity to log: INFO=0, WARNING=1, ERROR=2, FATAL=3.
  FLAGS_minloglevel = 0;
  Debug::DeathHandler dh;
  // InitGoogleLogging expects the program name (argv[0]); glog derives the log
  // file names from it, so passing a file path such as "./logs.log" is a misuse.
  google::InitGoogleLogging(argv[0]);
  ::testing::InitGoogleTest(&argc, argv);
  int ret = RUN_ALL_TESTS();
  return ret;
}
// MatMul test: multiplies a 3x2 matrix of ones by a 2x3 matrix of threes and
// checks all three fetched tensors. Each element of the 3x3 product is
// 1*3 + 1*3 = 6.
TEST(TfMatNulTests, MatMul) {
  Scope root = Scope::NewRootScope();
  auto a = ops::Fill(root, {3, 2}, 1.0f);
  auto b = ops::Fill(root, {2, 3}, 3.0f);
  auto mat_mul_ = ops::MatMul(root, a, b);
  // Graph-construction errors are recorded on the scope rather than thrown.
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Run() reports failure through its Status; on failure 'outputs' is not
  // populated, so indexing it unchecked would read out of bounds.
  TF_ASSERT_OK(session.Run({a, b, mat_mul_}, &outputs));
  ASSERT_EQ(outputs.size(), 3u);

  test::PrintTensorValue<float>(std::cout, outputs[0]);
  test::PrintTensorValue<float>(std::cout, outputs[1]);
  test::PrintTensorValue<float>(std::cout, outputs[2]);

  std::vector<float> v_a(6, 1.0f);
  test::ExpectTensorEqual<float>(outputs[0], test::AsTensor<float>({v_a.data(), v_a.size()}, {3, 2}));
  std::vector<float> v_b(6, 3.0f);
  test::ExpectTensorEqual<float>(outputs[1], test::AsTensor<float>({v_b.data(), v_b.size()}, {2, 3}));
  std::vector<float> v_mat_mul(9, 6.0f);
  test::ExpectTensorEqual<float>(outputs[2], test::AsTensor<float>({v_mat_mul.data(), v_mat_mul.size()}, {3, 3}));
}
程序输出如下:
data:image/s3,"s3://crabby-images/2c9af/2c9af49b235948b95cb70d862efc8d19a417660a" alt=""