tf2.0 tfserving
Back in 2018 I wrote about saving TF models as pb and serving them with TF Serving. That code no longer runs in a TF 2.0 environment, so I looked into it again.
The official example has changed as well; it now uses the tf.compat compatibility API.
I simplified the official version to make it clearer and more concise, as shown below:
import tensorflow as tf

def export():
    export_path = "model/half_plus_ten/1"
    # Note: if placeholder creation fails under eager execution
    # (tf.placeholder is not compatible with eager mode), calling
    # tf.compat.v1.disable_eager_execution() before building the graph may help.
    with tf.compat.v1.keras.backend.get_session() as sess:
        # Define the model: parameters, input and output
        a = tf.Variable(100.0)
        b = tf.Variable(0.05)
        x = tf.compat.v1.placeholder(tf.float32)
        y = tf.add(tf.multiply(a, x), b)
        sess.run(tf.compat.v1.global_variables_initializer())
        # Save in pb (SavedModel) format
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path)
        # Inputs and outputs must be tensors, wrapped as TensorInfo for the signature
        inputs = tf.compat.v1.saved_model.utils.build_tensor_info(x)
        outputs = tf.compat.v1.saved_model.utils.build_tensor_info(y)
        prediction_signature = (
            tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                inputs={'input': inputs},
                outputs={'output': outputs},
                method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME))
        builder.add_meta_graph_and_variables(
            sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict':
                    prediction_signature,
                # Not sure why the signature has to be registered twice, but without
                # the entry below TF Serving reports:
                # "error": "Serving signature name: \"serving_default\" not found in signature def"
                tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    prediction_signature,
            },
            main_op=tf.compat.v1.tables_initializer(),
            strip_default_attrs=True)
        builder.save()

if __name__ == "__main__":
    export()
Pull the TF Serving Docker image and start the service:
docker run -t --rm -p 8501:8501 \
  -v "$(pwd)/model/half_plus_ten:/models/half_plus_ten" \
  -e MODEL_NAME=half_plus_ten \
  tensorflow/serving
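The image's REST API listens on 8501. If you also want to use the gRPC client shown later, the gRPC port 8500 needs to be published as well, for example:

docker run -t --rm -p 8500:8500 -p 8501:8501 \
  -v "$(pwd)/model/half_plus_ten:/models/half_plus_ten" \
  -e MODEL_NAME=half_plus_ten \
  tensorflow/serving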
Call the service:
curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_ten:predict
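With the parameters defined above (a = 100.0, b = 0.05, so y = a*x + b), the response should look roughly like the following (values not taken from the original article):

{"predictions": [100.05, 200.05, 500.05]}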
Saving a Keras model as pb is even simpler, a single line of code; note that tensorflow.python.keras should be used.
model.save('model/keras_model/1', save_format='tf')
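For completeness, a minimal sketch with a made-up toy model (not from the original article; it uses the plain tf.keras API) showing the export end to end:

import tensorflow as tf

# Toy model just to illustrate the export call; any Keras model is saved the same way.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

# save_format='tf' writes a SavedModel (saved_model.pb + variables) into the
# version subdirectory "1", which is the directory layout TF Serving expects.
model.save('model/keras_model/1', save_format='tf')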
Inspect the details of a saved pb model (here a different Keras model, fm_item):
saved_model_cli show --dir model/fm_item/1 --all
You can see the pb model's inputs and outputs:
signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['item_id_hash_pos'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 1)
        name: serving_default_item_id_hash_pos:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['lambda_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 8)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict
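This model could also be queried over REST with named inputs. A sketch, assuming it was served under the (hypothetical) model name fm_item:

curl -d '{"instances": [{"item_id_hash_pos": [3428]}, {"item_id_hash_pos": [968]}]}' \
  -X POST http://localhost:8501/v1/models/fm_item:predict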
Calling the service via gRPC.
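As far as I know, the tensorflow_serving.apis stubs used below come from the tensorflow-serving-api pip package, so it may need to be installed first:

pip install tensorflow-serving-api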
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import grpc
import tensorflow as tf
import numpy as np

def request_server(server_url):
    '''
    Send an inference request to a TensorFlow Serving instance.
    :param server_url: address and port of TensorFlow Serving, str, e.g. '0.0.0.0:8500'
    :return: result array returned by the model, numpy array
    '''
    # Build the request.
    channel = grpc.insecure_channel(server_url)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = "half_plus_ten"  # model name, the MODEL_NAME the container was started with
    request.model_spec.signature_name = "serving_default"  # signature name, as shown by saved_model_cli
    # The input key must match the one in the signature; here it is the
    # "item_id_hash_pos" input of the Keras model inspected above, so the
    # model name and keys must of course refer to the same served model.
    x_data = [[3428], [968], [3], [2]]
    request.inputs["item_id_hash_pos"].CopyFrom(tf.make_tensor_proto(x_data, dtype=tf.int32))
    response = stub.Predict(request, 5.0)  # 5 secs timeout
    print(response.outputs["lambda_1"])
    # "lambda_1" is the output key shown by saved_model_cli
    return np.asarray(response.outputs["lambda_1"].float_val)

if __name__ == "__main__":
    print(request_server("0.0.0.0:8500"))
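As a side note, float_val is a flat list of values; to get the result back in its original shape, tf.make_ndarray should be able to rebuild the array directly from the returned TensorProto:

result = tf.make_ndarray(response.outputs["lambda_1"])  # numpy array with the model's (-1, 8) output shape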