使用Caffe的Python接口进行推理

2018-11-25  本文已影响0人  crazyhank

从网上下载下来Caffe代码进行编译后,在Caffe的源码目录下有一个python子目录,把这个子目录的路径加入到PYTHONPATH环境变量中去,然后就可以使用Caffe的Python API了,如下所示:

$ export PYTHONPATH=/home/hank/Study/CV/caffe.git/python/

打开Python,执行import caffe语句,如果没有报错,说明就可以正常使用了,如下:

hank@hank-desktop:~/Study/CV$ python
Python 2.7.6 (default, Nov 23 2017, 15:49:48)
[GCC 4.8.4] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import caffe
>>>

接下来我们使用Caffe的Python API对一个预训练好的模型进行加载并进行推理测试,在Caffe的源码目录下有一个下载由Caffe开发团队预训练好的模型的脚本,这里我们以caffenet模型为例,先执行下载脚本,如下:

$sudo python ./scripts/download_model_binary.py models/bvlc_reference_caffenet

下载完成后在models/bvlc_reference_caffenet目录下就会看到对应的模型(caffemodel文件以及prototxt文件),如下:

hank@hank-desktop:~/Study/CV/caffe.git$ ls -lh models/bvlc_reference_caffenet/
total 233M
-rw-rw-r-- 1 hank hank 233M 11月 24 23:23 bvlc_reference_caffenet.caffemodel
-rw-rw-r-- 1 hank hank 2.9K 11月 25 20:56 deploy.bak.prototxt
-rw-rw-r-- 1 hank hank 2.9K 11月 25 20:56 deploy.prototxt
-rw-rw-r-- 1 hank hank 1.3K  9月  9 08:46 readme.md
-rw-rw-r-- 1 hank hank  315  9月  9 08:46 solver.prototxt
-rw-rw-r-- 1 hank hank 5.6K  9月  9 08:46 train_val.prototxt

如果你要下载其他的模型,比如alexnet,可以执行以下脚本:

$ sudo python ./scripts/download_model_binary.py ./models/bvlc_alexnet/

模型下载成功后,我们使用python语言来编写推理的程序,先把源码贴出来:

#coding = utf-8
import numpy as np
import sys, os

caffe_root = '/home/hank/Study/CV/caffe.git/'
sys.path.insert(0, caffe_root + 'python')

import caffe

os.chdir(caffe_root)

#caffe.set_device(0)
#caffe.set_mode_gpu()

net_file = caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt'
caffe_model = caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'
mean_file = caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy'

# load model
net = caffe.Net(net_file, caffe_model, caffe.TEST)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))
transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
transformer.set_raw_scale('data', 255)
transformer.set_channel_swap('data', (2, 1, 0))

im = caffe.io.load_image(caffe_root+'examples/images/cat.jpg')
net.blobs['data'].data[...] = transformer.preprocess('data',im)
out = net.forward()

print net.blobs['data'].data.shape
print net.blobs['data'].data.dtype
print net.blobs['prob'].data.shape
print net.blobs['prob'].data.dtype


imagenet_labels_filename = caffe_root + 'data/ilsvrc12/synset_words.txt'
labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')

top_k = net.blobs['prob'].data[0].flatten().argsort()[-1:-6:-1]
for i in np.arange(top_k.size):
    print top_k[i], labels[top_k[i]]

使用caffe源码里自带的一张cat照片进行推理(分类任务),最后输出的结果如下:

(1, 3, 227, 227)
float32
(1, 1000)
float32
281 n02123045 tabby, tabby cat
282 n02123159 tiger cat
285 n02124075 Egyptian cat
277 n02119022 red fox, Vulpes vulpes
287 n02127052 lynx, catamount
上一篇 下一篇

猜你喜欢

热点阅读