其他分享
首页 > 其他分享> > tensorflow/serving部署keras模型

tensorflow/serving部署keras模型

作者:互联网

之前写了一篇tensorflow/serving部署tensorflow模型的文章,记录了详细的操作步骤与常见的错误及解决方案,具体见:TensorFlow Serving模型转换与部署

本文主要记录tensorflow/serving部署keras模型过程中的一些重要步骤,以便后续查阅。

我们在keras中保存模型通常用model.save或者model.save_weights函数。
其中,model.save函数保存的模型往往比的是模型的结构与权重,而model.save_weights函数保存的仅仅是模型的结构,因此model.save函数保存的模型往往比model.save_weights函数保存的模型要大些。

在前一篇tensorflow/serving介绍中TensorFlow Serving模型转换与部署,我们知道tensorflow/serving需要pb格式的模型,而本篇文章我们讨论的keras模型是.h5.weights格式的,因此,首先我们需要将.h5.weights格式的keras模型转换为tensorflow/serving框架可识别的pb格式模型,转换代码如下:

def keras_model_to_tfs(model, export_path):
    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'input_x': model.input}, 
        outputs={'output_y': model.output}
    )
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess=K.get_session(),
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
        },
        legacy_init_op=legacy_init_op)
    builder.save()
    print('Build done.')

简要说明一下keras_model_to_tfs函数的参数
model:导入的keras模型,用keras的load_modelload_weights导入的模型
export_path:转换成pb格式模型后的保存路径

模型转换完成后,剩下的工作就是部署tensorflow/serving框架,并利用grpc接口调用模型预测。
关于具体的tensorflow/serving的部署,可参考之前文章:TensorFlow Serving模型转换与部署,预测代码在之前那篇文章中也有,本文再次贴出一个。

def tfserving_grpc(title, content):
    content = content or title
    content = filter_waste(content)
    model_dir = os.path.join(project_path, 'models_weights')
    with open(os.path.join(model_dir, 'tokenizer.plk'), 'rb') as f:
        tokenizer = pickle.load(f)
    x = tokenizer.texts_to_sequences([jieba.lcut(content)])
    x = x[0]
    if len(x) > MAX_LEN:
        x = x[:MAX_LEN]
    else:
        x = x + [0] * (MAX_LEN - len(x))

    # ip地址为部署tensorflow/serving的IP
    channel = grpc.insecure_channel('xx.xx.xx.xx:8500')  
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'new_yq_model'
    # request.model_spec.version.value = 1000001
    request.model_spec.signature_name = 'serving_default'

    request.inputs["input_x"].CopyFrom(tf.contrib.util.make_tensor_proto(np.array([x], dtype=np.float)))
    response = stub.Predict(request, 10.0)

    results = {}
    for key in response.outputs:
        tensor_proto = response.outputs[key]
        results[key] = tf.contrib.util.make_ndarray(tensor_proto)

    return results

最后给一个main函数的整体过程代码。

model = build_model(len(tokenizer.index_word))
model.load_weights(os.path.join(model_dir, 'best_model.weights'))
model.summary()
export_path = './tfs_models'
keras_model_to_tfs(model, export_path)

参考

使用tensorflow serving部署keras模型(tensorflow 2.0.0)
keras、tensorflow serving踩坑记

标签:serving,keras,模型,path,tensorflow,model
来源: https://blog.csdn.net/tianyunzqs/article/details/116303326