pytroch学习(二十一)—C++(libTorch)调用py
前言
当我们训练好一个CNN模型之后,可能要集成到项目工程中,或者移植到到不同的开发平台(比如Android, IOS), 一般项目工程或者App大多数采用C/C++, Java等语言,但是采用pytroch训练的模型用的是python语言,这样就存在一个问题,如何使用C/C++、Java调用预训练好的模型, 如果解决了这个问题,那么训练好的模型才可以走出实验室,在App中得到广泛应用。
本章内容将完整介绍pytroch
预训练模型的C++调用,关于Java的调用,其实也不难,如果掌握了C++调用 pytorch模型的方法, Java可以通过JNI调用C++。
开发环境
- Ubuntu 18.04
- Clion
- CMake
- opencv
- libTorch
配置libTorch
首先,在pytorch官网下载libtroch, 官网提供了win/linux/Mac系统编译好的库,省去了编译库的过程。
image.png下载好之后,解压到一个路径即可。
image.png image.png image.png
测试一个简单demo
创建一个目录,example-app
image.png新建2个文件
image.png
- example-app.cpp
#include <torch/torch.h>
#include <iostream>
int main() {
torch::Tensor tensor = torch::rand({2, 3});
std::cout << tensor << std::endl;
}
- CMakeLists.txt
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
然后打开终端,输入:mkdir build
继续: cd build
开始编译:
image.png image.png
执行:
image.png
libTorch调用预训练好的性别分类模型
上面的例子是pytorch官网的demo, 下面本人模仿官方的demo, 将使用libTorch C++ API调用自己预训练好的.pth
模型。
第一步
将预训练好的模型转换
import torch
import torchvision.models as models
import torch.nn as nn
# 载入预训练的型
model = models.squeezenet1_1(pretrained=False)
model.classifier[1] = nn.Conv2d(in_channels=512, out_channels=2, kernel_size=(1, 1), stride=(1, 1))
model.num_classes = 2
model.load_state_dict(torch.load('./model/model_squeezenet_utk_face_20.pth', map_location='cpu'))
model.eval()
print(model)
# model = models.resnet18(pretrained=True)
# print(model)
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
output = traced_script_module(torch.ones(1, 3, 224, 224))
print(output)
# ----------------------------------
traced_script_module.save("./model/model_squeezenet_utkface.pt")
第二步
编写C++代码
- 准备好上一步骤转换的模型文件
- 准备几张测试图像
由于涉及到图像的加载与处理,本人使用opencv进行读取和处理。
Tips:
训练过程中,采用PIL.Image
加载图像(3通道 RGB),然后Resize
到224 x 224大小, 之后再进行ToTensor
。因此使用C++ libTorch时候也需要按照上述过程对图像进行预处理。
-
cv::imread()
默认读取为三通道BGR,需要进行B/R通道交换,这里采用cv::cvtColor()
实现。 -
缩放
cv::resize()
实现。 -
opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是
np.transpose()
操作,这里使用tensor.permut()
实现,效果是一样的。 -
数据归一化,采用
tensor.div(255)
实现。
#include <torch/script.h> // One-stop header.
#include <opencv2/opencv.hpp>
#include <iostream>
#include <memory>
//https://pytorch.org/tutorials/advanced/cpp_export.html
string image_path = "/home/weipenghui/Project-dev/Cpp/testLibTorch2/image";
int main(int argc, const char* argv[]) {
// Deserialize the ScriptModule from a file using torch::jit::load().
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("/home/weipenghui/Project-dev/Cpp/testLibTorch2/model/model_squeezenet_utkface.pt");
assert(module != nullptr);
std::cout << "ok\n";
//输入图像
auto image = cv::imread(image_path +"/"+ "7.jpg",cv::ImreadModes::IMREAD_COLOR);
cv::Mat image_transfomed;
cv::resize(image, image_transfomed, cv::Size(224, 224));
cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);
// 图像转换为Tensor
torch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols,3},torch::kByte);
tensor_image = tensor_image.permute({2,0,1});
tensor_image = tensor_image.toType(torch::kFloat);
tensor_image = tensor_image.div(255);
tensor_image = tensor_image.unsqueeze(0);
// 网络前向计算
// Execute the model and turn its output into a tensor.
at::Tensor output = module->forward({tensor_image}).toTensor();
auto max_result = output.max(1,true);
auto max_index = std::get<1>(max_result).item<float>();
if (max_index == 0){
cv::putText(image, "male", cv::Point(50, 50), 1, 1,cv::Scalar(0, 255, 255));
}else{
cv::putText(image, "female", cv::Point(50, 50), 1, 1,cv::Scalar(0, 255, 255));
}
cv::imwrite("./result7.jpg", image);
//cv::imshow("image", image);
//cv::waitKey(0);
// at::Tensor prob = torch::softmax(output,1);
// auto prediction = torch::argmax(output, 1);
//
// auto aa = prediction.slice(/*dim=*/0, /*start=*/0, /*end=*/2).item();
//
// std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/2) << '\n';
// std::cout << prob.slice(/*dim=*/1, /*start=*/0, /*end=*/2) << '\n';
// std::cout <<prediction.slice(/*dim=*/0, /*start=*/0, /*end=*/2)<<"\n";
}
编写CMakeLists.txt文件
目的是将libTorch
, opencv
配置好,确保程序可以正常编译链接。
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
#include_directories(/home/weipenghui/Lib-dev/opencv_3.4.3_contrib/opencv-3.4.3/build/install_cv/include)
# set(CMAKE_PREFIX_PATH "/home/weipenghui/Lib-dev/opencv_3.4.3_contrib/opencv-3.4.3/build/install_cv")
find_package(OpenCV REQUIRED)
set(CMAKE_PREFIX_PATH
/home/weipenghui/Lib-dev/libtorch-shared-with-deps-latest/libtorch
/home/weipenghui/Lib-dev/opencv_3.4.3_contrib/opencv-3.4.3/build/install_cv)
find_package(Torch REQUIRED)
add_executable(example-app main.cpp)
target_link_libraries(example-app ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
识别结果
result.jpg result2.jpg result3.jpg result4.jpg result5.jpg result6.jpg result7.jpgEnd
参考:
https://pytorch.org/cppdocs/installing.html
https://pytorch.org/tutorials/advanced/cpp_export.html