使用ops::DecodeCSV算子重写鸢尾花数据集预测
2022-03-22 本文已影响0人
FredricZhu
本文使用 ops::DecodeCSV 算子重写鸢尾花数据集预测，这样读取 CSV 数据时就不再需要依赖第三方的 hmdf::DataFrame 了。
程序结构如下,
图片.png
conanfile.txt
[requires]
gtest/1.10.0
glog/0.4.0
protobuf/3.9.1
dataframe/1.20.0
[generators]
cmake
CMakeLists.txt
# Build script: compiles the shared helper sources into one library and builds
# one executable per *.cpp file found in this directory.
cmake_minimum_required(VERSION 3.3)
project(test_iris_predict)
set(CMAKE_CXX_STANDARD 17)
# Pass -g to the compiler for every target.
add_definitions(-g)
# Conan-generated build info; conan_basic_setup() defines CONAN_LIBS used below.
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()
# Shared header root, two directories above this one.
set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
include_directories(${INCLUDE_DIRS})
find_package(TensorflowCC REQUIRED)
# Every .cpp file in this directory becomes its own executable (see foreach below).
file( GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
# Helper implementations shared by all executables: tensor test utilities,
# death_handler sources and the df/ sources.
file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp)
add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
# PUBLIC so that executables linking the helper library also get the Conan and TensorFlow libraries.
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC)
foreach( test_file ${test_file_list} )
# The executable name is the source path relative to this directory with ".cpp" removed.
file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
string(REPLACE ".cpp" "" file ${filename})
add_executable(${file} ${test_file})
target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
endforeach( test_file ${test_file_list})
tf_iris_model_test.cpp
#include <algorithm>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>
#include <tensorflow/c/c_api.h>
#include "death_handler/death_handler.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
// Bring the tensorflow namespace into scope so Scope, Tensor, ClientSession,
// ops::... and tstring can be written unqualified below.
using namespace tensorflow;
// Alias for a brace-enclosed list of 64-bit integers.
// NOTE(review): not referenced anywhere in the visible part of this file.
using BatchDef = std::initializer_list<tensorflow::int64>;
// Path of the iris CSV file read by GetCSVLines(). It is relative, so it is
// resolved against the working directory the test binary is started from.
char const* data_csv = "../data/iris.csv";
// Test entry point: hands the command line to GoogleTest and reports the
// overall result of all registered tests as the process exit code.
int main(int argc, char** argv) {
    // Kept alive for the whole run; presumably installs crash handlers that
    // dump a stack trace (see death_handler/death_handler.h) - TODO confirm.
    Debug::DeathHandler dh;
    // GoogleTest consumes the flags it recognises from argc/argv.
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}
// 获取CSV文件行
std::vector<tstring> GetCSVLines() {
std::fstream ifs {data_csv};
std::string line;
std::vector<tstring> lines;
while(std::getline(ifs, line)) {
lines.emplace_back(tstring(line));
}
return lines;
}
// Builds the model input from the iris CSV file: decodes every line with
// ops::DecodeCSV, takes decoded columns 1..4 and joins them side by side into
// a single (rows, 4) float tensor.
//
// Aborts via TF_CHECK_OK if running the graph fails, instead of indexing into
// an empty result vector.
Tensor GetInputTensor() {
    // Build the iris data set.
    // https://www.tensorflow.org/versions/r2.6/api_docs/cc/class/tensorflow/ops/decode-c-s-v
    Scope root = Scope::NewRootScope();
    ClientSession session(root);
    auto lines = GetCSVLines();
    auto input = test::AsTensor<tensorflow::tstring>(lines, {static_cast<long>(lines.size())});
    // DecodeCSV uses the default values to infer the number of output columns
    // and their types, so they must match the file layout:
    // int, float, float, float, float, int.
    // 1. Decode the CSV lines into one tensor per column.
    auto decode_csv_op = ops::DecodeCSV(root, input, {Input(1), Input(1.0f), Input(1.0f), Input(1.0f), Input(1.0f), Input(1)});
    // 2. Reshape each feature column to (rows, 1) so the columns can be joined
    //    along axis 1. -1 lets Reshape derive the row count from the data
    //    instead of hard-coding 150.
    auto input_1 = ops::Reshape(root, decode_csv_op.output[1], {-1, 1});
    auto input_2 = ops::Reshape(root, decode_csv_op.output[2], {-1, 1});
    auto input_3 = ops::Reshape(root, decode_csv_op.output[3], {-1, 1});
    auto input_4 = ops::Reshape(root, decode_csv_op.output[4], {-1, 1});
    // 3. Concatenate along axis 1 into (rows, 4). The axis is passed as a
    //    scalar, which is what Concat documents (0-D int32).
    auto concat_op = ops::Concat(root, {Input(input_1), Input(input_2), Input(input_3), Input(input_4)}, 1);
    // 4. Run the graph and return the single fetched tensor.
    std::vector<Tensor> outputs_concat{};
    TF_CHECK_OK(session.Run({concat_op}, &outputs_concat));
    return outputs_concat[0];
}
std::vector<int> GetOutputBatches() {
Scope root = Scope::NewRootScope();
auto lines = GetCSVLines();
auto input = test::AsTensor<tensorflow::tstring>(lines, {(long)lines.size()});
// DecodeCSV函数使用Default Value来推算 输出张量的列数 和类型,不能随便填
auto decode_csv_op = ops::DecodeCSV(root, input, {Input(1), Input(1.0f), Input(1.0f), Input(1.0f), Input(1.0f), Input(1)});
ClientSession session(root);
std::vector<Tensor> outputs;
session.Run(decode_csv_op.output, &outputs);
return test::GetTensorValue<int>(outputs[5]);
}
std::vector<int> ConvertTensorToIndexValue(Tensor const& tensor_) {
auto tensor_res = test::GetTensorValue<float>(tensor_);
std::vector<int> predict_res{};
for(int i=0; i<tensor_res.size(); ++i) {
if(i!=0 && (i+1)%3==0) {
auto max_idx = std::max_element(tensor_res.begin() + (i-2), tensor_res.begin() + (i+1)) -(tensor_res.begin() + (i-2));
predict_res.emplace_back((int)max_idx);
}
}
return predict_res;
}
// Loads the SavedModel from ../iris_model, runs it on the features read from
// the iris CSV file, prints the raw scores and the predicted class indices,
// and reports how many predictions match the labels from the CSV file.
//
// The test fails (rather than reading out of bounds) when the model returns
// no tensor, when there are no labels, or when the number of predictions
// differs from the number of labels.
TEST(TfIrisModelTest, LoadAndPredict) {
    SavedModelBundleLite bundle;
    SessionOptions session_options;
    RunOptions run_options;
    const string export_dir = "../iris_model";
    TF_CHECK_OK(LoadSavedModel(session_options, run_options, export_dir,
                               {kSavedModelTagServe}, &bundle));

    auto input_tensor = GetInputTensor();
    std::vector<tensorflow::Tensor> out_tensors;
    // Feed/fetch names are those of the exported serving signature.
    TF_CHECK_OK(bundle.GetSession()->Run({{"serving_default_input_1:0", input_tensor}},
                                         {"StatefulPartitionedCall:0"}, {}, &out_tensors));
    ASSERT_FALSE(out_tensors.empty());

    std::cout << "Print Tensor Value\n";
    test::PrintTensorValue<float>(std::cout, out_tensors[0], 3);
    std::cout << "\n";

    std::cout << "Print Index Value\n";
    auto predict_res = ConvertTensorToIndexValue(out_tensors[0]);
    for (auto ele : predict_res) {
        std::cout << ele << "\n";
    }

    auto labels = GetOutputBatches();
    // Guard the element-wise comparison and the division below.
    ASSERT_FALSE(labels.empty());
    ASSERT_EQ(predict_res.size(), labels.size());

    int correct{0};
    for (std::size_t i = 0; i < predict_res.size(); ++i) {
        if (predict_res[i] == labels[i]) {
            ++correct;
        }
    }
    std::cout << "Total correct: " << correct << "\n";
    std::cout << "Total datasets: " << labels.size() << "\n";
    std::cout << "Accuracy is: " << static_cast<float>(correct) / labels.size() << "\n";
}
程序输出如下,
图片.png