tensorflow

tf2.0 tfserving

2021-01-12  本文已影响0人  夕一啊

2018年写过tf保存为pb使用tfserving,现在发现tf2.0环境运行不了了,于是重新研究下
官方例子也变了,使用tf.compat 兼容api实现
简化了官方版本,更清晰简洁,如下所示:

import tensorflow as tf

def export():
    export_path = "model/half_plus_ten/1"
    with tf.compat.v1.keras.backend.get_session() as sess:
        # 定义模型,参数、输入输入
        a = tf.Variable(100.0)
        b = tf.Variable(0.05)
        x = tf.compat.v1.placeholder(tf.float32)
        y = tf.add(tf.multiply(a, x), b)
        sess.run(tf.compat.v1.global_variables_initializer())

        # 存储为pb格式
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path)
        #输入输出必须是tensor,签名化       
        inputs = tf.compat.v1.saved_model.utils.build_tensor_info(x)
        outputs = tf.compat.v1.saved_model.utils.build_tensor_info(y)
        prediction_signature = (
            tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                inputs={'input': inputs},
                outputs={'output': outputs},
                method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME))

        builder.add_meta_graph_and_variables(
            sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict':
                    prediction_signature,         
        # 不知道为什么需要两次签名,但是少了下面这个会报错 
        #"error": "Serving signature name: \"serving_default\" not found in signature def"             
                    tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                            prediction_signature,
            },
            main_op=tf.compat.v1.tables_initializer(),
            strip_default_attrs=True)

        builder.save()


if __name__ == "__main__":
    export()

拉tfservingdocker起服务

docker run -t --rm -p 8501:8501 \
   -v "$(pwd)/model/half_plus_ten:/models/half_plus_ten" \
   -e MODEL_NAME=half_plus_ten \
   tensorflow/serving

调用服务

curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_ten:predict

keras模型保存成pb更简单, 一行代码解决,注意要使用tensorflow.python.keras

model.save('model/keras_model/1', save_format='tf')

查看保存好的pb模型细节

saved_model_cli show --dir model/fm_item/1 --all

可以看到pb模型输入输出

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['item_id_hash_pos'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 1)
        name: serving_default_item_id_hash_pos:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['lambda_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 8)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

使用grpc调用服务:

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import grpc
import tensorflow as tf

import numpy as np

def request_server(server_url):
    '''
    用于向TensorFlow Serving服务请求推理结果的函数。
    :param img_resized: 经过预处理的待推理图片数组,numpy array,shape:(h, w, 3)
    :param server_url: TensorFlow Serving的地址加端口,str,如:'0.0.0.0:8500' 
    :return: 模型返回的结果数组,numpy array
    '''
    # Request.
    channel = grpc.insecure_channel(server_url)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = "half_plus_ten"  # 模型名称,启动容器命令的model_name参数
    request.model_spec.signature_name = "serving_default"  # 签名名称,刚才叫你记下来的
    # "input_1"是你导出模型时设置的输入名称,刚才叫你记下来的
    x_data = [[3428],[968],[3],[2]]
    request.inputs["item_id_hash_pos"].CopyFrom(tf.make_tensor_proto(x_data, dtype=tf.int32))
    response = stub.Predict(request, 5.0)  # 5 secs timeout
    print(response.outputs["lambda_1"])
    return np.asarray(response.outputs["lambda_1"].float_val) # fc2为输出名称,刚才叫你记下来的


if __name__ == "__main__":
    print(request_server("0.0.0.0:8500"))
上一篇 下一篇

猜你喜欢

热点阅读