部署tensorflow serving+python,java

2019-01-30  本文已影响0人  wxrg2012

本文介绍使用docker的方法部署tensorflow serving,并提供pythonjava client代码实例。(本文参考了较多博文和tensorflow官方文档,旨在补充多数博文遗留的坑,和精简官方文档的繁琐)。
为了避免bazel编译源码这个大坑(会报一些奇怪的错误,主要是各个依赖项的版本不对应),本文直接选择docker的方式部署tensorflow serving。
注:只需按照步骤一步一步来,就能从零到部署成功,最后会提供一个使用案例:文本分类模型

1 Docker安装

1.1 Mac环境下安装

参考网站
建议选择手动安装,安装完毕后,选择(Check for Updates)更新到最新版本

1.2 centos环境下安装

前提条件:CentOS 7 上,要求系统为64位、系统内核版本为 3.10 以上,通过指令uname -r 查看自己的系统版本

移除旧的版本:

$ sudo yum remove docker \
                  docker-client \
                  docker-client-latest \
                  docker-common \
                  docker-latest \
                  docker-latest-logrotate \
                  docker-logrotate \
                  docker-selinux \
                  docker-engine-selinux \
                  docker-engine

安装依赖项:
sudo yum install -y yum-utils device-mapper-persistent-data lvm2
添加源信息:
sudo yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo
更新 yum 缓存:
sudo yum makecache fast
安装 Docker-ce:
sudo yum -y install docker-ce
启动 Docker 后台服务:
sudo systemctl start docker
测试运行 hello-world:
docker run hello-world 或者 直接查看版本 docker --version

2 serving部署

2.1 拉取serving 镜像

docker pull tensorflow/serving
完成之后 查看安装好的镜像
docker images

2.2 导出模型

serving不能直接使用以HDF5和.ckpt方式保存的模型,需要进行一次转化,本文以keras保存的HDF5文件为例进行介绍,.ckpt转换方式大同小异,游客可自行查询。

import tensorflow as tf
from keras import backend as K
from keras.models import Sequential, Model
from os.path import isfile
from keras.models import load_model
import os

def save_model_to_serving(model, export_version, export_path='prod_models'):
    print(model.input, model.output)
    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'textdata': model.input}, outputs={'market': model.output})
    export_path = os.path.join(
        tf.compat.as_bytes(export_path),
        tf.compat.as_bytes(str(export_version)))
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess=K.get_session(),
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'market_classification': signature,
        },
        legacy_init_op=legacy_init_op)
    builder.save()
model = load_model('自己的路径/blistm-checkpoint-02e-val_acc_0.96.hdf5')
save_model_to_serving(model, "1", "bgru_serving")#bgru_serving表示转换后的模型会存储到该路径下

模型转化结束后会生成下面几个文件


2.3 运行容器

docker run -p 8500:8500 \
      --mount type=bind,source=自己的路径/bgru_serving/,target=/models/market_blstm \
      -e MODEL_NAME=market_blstm -t tensorflow/serving

注:测试建议使用8500端口 ,自己的路径->绝对路径 (重点)
各个参数的含义:

  • -p 8500:8500 :指的是开放8500这个gRPC端口
  • --mount type=bind, source=自己的路径/bgru_serving/, target=/models/market_blstm:把你导出的本地模型文件夹挂载到docker container的/models/market_blstm这个文件夹,tensorflow serving会从容器内的/models/market_blstm文件夹里面找到你的模型
  • --MODEL_NAME:模型名字,在导出模型的时候设置的名字
  • -t 指定使用tensorflow/serving这个镜像,可以替换其他版本,例如tensorflow/serving:latest-gpu,但你需要docker pull tensorflow/serving:latest-gpu把这个镜像拉下来

3 client客户端

3.1 python 案例

注:最好使用python3.5+,不然如果使用高版本的tensorflow会报错
安装依赖库sudo pip3 install tensorflow-serving-api
客户端代码

from __future__ import print_function
from grpc.beta import implementations
import tensorflow as tf
import numpy as np
import re,json,jieba,time
import codecs
import random,time

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

def loadData(filename):  #加载json文件 生成字典
    with codecs.open(filename,'r','utf-8') as fr:
        resdict = json.load(fr)
    return resdict


vocab = loadData('vocab_bgru.dict')#   加载词典 ,格式:"中国":12045

def denoise(text): #文本预处理并粉刺,再根据embegging所需的词典生成词的索引矩阵----处理单条文本数据
    x_train_word_ids = []
    tem = []
    patten=re.compile(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b',re.S)
    line = text.strip()
    line = patten.sub('',line.decode("utf-8","ignore"))
    line = re.sub(r'{url(.*)网页链接}','',line)
    line = line.replace('\\','').replace('\n',' ').replace('https://',' ')
    wordlist = [emt.strip() for emt in jieba.cut(line) if len(emt.strip())>=2]
    for i,word in enumerate(wordlist):
        try:code = vocab[word]
        except:
            try:code = vocab[word.encode('utf-8')]
            except:continue
        tem.append(code)
        x_train_word_ids.append(tem)
    if len(x_train_word_ids)==0:return [[0]]
    return x_train_word_ids

def pad_sequences(x_train_word_ids,maxlen=64): #根据denoise函数得到的一条文本的索引矩阵生成符合lstm输入的词向量
    len_x = len(x_train_word_ids[0])
    if len_x>maxlen:
        res = [x_train_word_ids[0][i] for i in range(len_x-maxlen,len_x)]
        return res
    else:
        res = [0]*maxlen
        for i,emt in enumerate(x_train_word_ids[0]):
            res[maxlen-len_x+i]=emt
        return res


tf.app.flags.DEFINE_string('server', '127.0.0.1:8500',
                           'PredictionService host:port') #ip和端口,ip可换成要连接的服务器ip
FLAGS = tf.app.flags.FLAGS

start_time = time.time()

batch_size = 120

host,port = FLAGS.server.split(":")

channel = implementations.insecure_channel(host,int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'market_blstm' # 这个name跟tensorflow_model_server  --model_name="market_blstm" 对应

request.model_spec.signature_name = 'market_classification' # 这个signature_name  跟2.2模型导出中的market_classification 对应

text_list = ['吴亦凡同款 Sup扎染卫衣 全身顶级数码直喷 印花带做旧感 就是看起来脏脏的 一件衣服印花大几十块 完美还原面料为420G毛圈轻捉毛 质感很好',"360儿童5周年不止5折# 360儿童手表五周年&双十一特惠! 喜欢![失望]"]

x_train = np.array([pad_sequences(denoise(text)) for text in text_list])
request.inputs['textdata'].CopyFrom(
                  tf.contrib.util.make_tensor_proto(x_train, shape=[batch_size,64],dtype=tf.float32)) # shape跟 keras的model.input类型对应,且textdata对应2.2中的textdata
result = stub.Predict(request, 10.0)
reslist = result.outputs['market'].float_val
print(reslist)

结果如下:
[0.013646061532199383, 0.9863539338111877, 0.16853764653205872, 0.8314623832702637] 每两个是一对预测数据,例如0.013646061532199383, 0.9863539338111877表示分别表示text_list中第一条数据属于0类的概率为0.013646061532199383,1类的概率为0.9863539338111877

3.2 java案例

pom.xml文件中的依赖项:

<dependencies>
        <dependency>
            <groupId>com.yesup.oss</groupId>
            <artifactId>tensorflow-client</artifactId>
            <version>1.4-2</version>
        </dependency>

        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-netty</artifactId>
            <version>1.7.0</version>
        </dependency>

        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-tcnative-boringssl-static</artifactId>
            <version>2.0.7.Final</version>
        </dependency>

        <dependency>
            <groupId>com.huaban</groupId>
            <artifactId>jieba-analysis</artifactId>
            <version>1.0.2</version>
        </dependency>

        <dependency>
            <groupId>net.sf.json-lib</groupId>
            <artifactId>json-lib</artifactId>
            <version>2.4</version>
            <classifier>jdk15</classifier>
        </dependency>

        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.6</version>
        </dependency>
    </dependencies>

具体代码:

import com.huaban.analysis.jieba.JiebaSegmenter;
import com.huaban.analysis.jieba.WordDictionary;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import net.sf.json.JSONObject;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

public class TensorServClient {

    PredictionServiceGrpc.PredictionServiceBlockingStub stub = null;

    private JiebaSegmenter segmenter;
    private JSONObject json;

    private static int maxlen = 64;   //padding的最大长度
    private static int batch = 200;

    public TensorServClient(){
        ManagedChannel channel = ManagedChannelBuilder.forAddress("127.0.0.1",8500).usePlaintext(true).build();
        //这里还是先用block模式
        stub = PredictionServiceGrpc.newBlockingStub(channel);

        WordDictionary dictAdd = WordDictionary.getInstance();
        dictAdd.loadUserDict(Paths.get("jiebaextradic_java.dict"));//加载自定义词典
        segmenter = new JiebaSegmenter();

        try {
            json = LoadJsonFile.load("vocab_bgru.dict"); //加载词位置索引词典 ,格式:"中国":12045
        }catch (Exception ex){
            ex.printStackTrace();
        }
    }
    private ArrayList<Integer> denoise(String line){
        ArrayList<Integer>x_train_word_ids = new ArrayList<Integer>();
        line = line.replaceAll("(http|ftp|https):\\/\\/[\\w\\-_]+(\\.[\\w\\-_]+)+([\\w\\-\\.,@?^=%&amp;:/~\\+#]*[\\w\\-\\@?^=%&amp;/~\\+#])?","");
        line = line.replaceAll("\\{url(.*)网页链接\\}","");
        line = line.replaceAll("\\\\","").replaceAll("\\r|\\n","").replaceAll("https://","");
        ArrayList<String> wordjiebaList = (ArrayList<String>) segmenter.sentenceProcess(line);
        for (String word:wordjiebaList) {
            try {
                if (this.json.containsKey(word)){
                    x_train_word_ids.add(this.json.getInt(word));
                }
            }catch (Exception e){
                x_train_word_ids.add(0);
            }
        }
        return x_train_word_ids;
    }
    private float[] padSequences(ArrayList<Integer>x_train_word_ids){
        float []res=new float[maxlen];
        int len_x = x_train_word_ids.size();
        if (len_x>maxlen){
            for (int i = len_x-maxlen,j=0; i < len_x; i++,j++) {
                res[j]=x_train_word_ids.get(i);
            }
            return res;
        }else {
            for (int i = 0; i < len_x; i++) {
                res[maxlen-len_x+i]=x_train_word_ids.get(i);
            }
            return res;
        }
    }
    private float[][]gen_predict_data(String []textlist){
        float [][] predict_data = new float[batch][maxlen];
        for (int i = 0; i < textlist.length; i++) {
            predict_data[i]=padSequences(denoise(textlist[i]));
        }
        return predict_data;
    }

public void predict(String[] textlist){
        //        //创建请求
        Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
        //模型名称和模型方法名预设
        Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName("market_blstm");
        modelSpecBuilder.setSignatureName("market_classification");
        predictRequestBuilder.setModelSpec(modelSpecBuilder);

        //设置入参,访问默认是最新版本,如果需要特定版本可以使用tensorProtoBuilder.setVersionNumber方法
        TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
        TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();

        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(batch));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(maxlen));

        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());

        float[][]featuresTensorData = gen_predict_data(textlist);

        for (int i = 0; i < featuresTensorData.length; ++i) {
            for (int j = 0; j < featuresTensorData[i].length; ++j) {
                tensorProtoBuilder.addFloatVal(featuresTensorData[i][j]);
            }
        }

        predictRequestBuilder.putInputs("textdata",tensorProtoBuilder.build());
        //访问并获取结果
        Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
        TensorProto result = predictResponse.toBuilder().getOutputsOrThrow("market");
        List<Float> reslist = result.getFloatValList();
}

public static void main(String[] args) throws Exception{
        long startTime = System.currentTimeMillis();
        TensorServClient tensorServClient = new TensorServClient();
        long midTime = System.currentTimeMillis();
        String[] textlist = {"吴亦凡同款 Sup扎染卫衣 全身顶级数码直喷 印花带做旧感 就是看起来脏脏的 一件衣服印花大几十块 完美还原面料为420G毛圈轻捉毛 质感很好","360儿童5周年不止5折# 360儿童手表五周年&双十一特惠! 喜欢![失望]",....};//这个数组的长度为 batch ,方便批处理
        tensorServClient.predict(textlist); 
 }
}

注:java 案例中textlist的长度为batch,每个位置上是一条文本;结果与python案例保持一致,亦是两个一对

上一篇下一篇

猜你喜欢

热点阅读