tensorflow tfserving 部署记录
作者:互联网
1,环境
keras2.4.3
tensorflow2.2
模型为keras的h5格式
keras-bert 0.88
wsl2下,docker环境部署,nividia-container-toolkit
wsl2安装nividia-container-toolkit参考,win10版本请务必更新为21H2以上
https://docs.microsoft.com/zh-cn/windows/wsl/tutorials/gpu-compute
2,模型转PB格式
tfserving部署需要pb格式的模型
在keras下即可转换,转换代码
点击查看代码
def get_mycustom_objects():
# 模型在训练的时候会有自定义的一系列函数
# 在从文件加载模型的时候,需要将其先前使用的一系列函数配件也传递给它
my_objects = {
'acc_top2': acc_top2,
"metric_precision": metric_precision,
"metric_recall": metric_recall,
"metric_F1score": metric_F1score
}
from keras_bert import get_custom_objects
custom_objects = get_custom_objects()
custom_objects.update(my_objects)
return custom_objects
def h5_to_pb(h5_model_path):
"""
将h5模型转为pb格式
:param h5_model_path:
:return:
"""
model = load_model(h5_model_path, compile=False, custom_objects=get_mycustom_objects())
model.save('trained_model/1/', save_format='tf')
print("转换结束")
转换结束后,模型是一个文件夹,以数字命名,表示版本,从0开始
3,查看PB模型的输入
cmd切换到PB模型所在目录,输入命令:
saved_model_cli show --dir=./ --all
命令输出如下,其中input_1和input_2为输入,该模型为bert NLP分类模型,两个输入
输出为dense_1,但是后面不需要注意这个
serving_default为后面python post数据中注释的signature_name
按h5_to_pb导出h5模型为pb模型,没有对模型signature_name做定义,默认就是serving_default,所以post的时候可以不传,除非你对pb模型自定义了
点击查看命令输出
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['__saved_model_init_op']:
The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
outputs['__saved_model_init_op'] tensor_info:
dtype: DT_INVALID
shape: unknown_rank
name: NoOp
Method name is:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['input_1'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1)
name: serving_default_input_1:0
inputs['input_2'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1)
name: serving_default_input_2:0
The given SavedModel SignatureDef contains the following output(s):
outputs['dense_1'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 4)
name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict
点击查看代码
import json
import requests, codecs
from keras_bert import Tokenizer
from keras.utils import to_categorical
import numpy as np
model_name ='xizang'
maxlen = 180
dict_path = r'bert/vocab.txt'
label2int = {
"0": 0,
"1": 1,
"2": 2,
"?": 3,
"?": 3,
"3": 3,
"4": 1
}
numofcategories = len(set(label2int.values()))
# 重写tokenizer
class OurTokenizer(Tokenizer):
def _tokenize(self, text):
R = []
for c in text:
if c in self._token_dict:
R.append(c)
elif self._is_space(c):
R.append('[unused1]') # 用[unused1]来表示空格类字符
else:
R.append('[UNK]') # 不在列表的字符用[UNK]表示 UNK是unknown的意思
return R
def get_token_dict():
"""
# 将词表中的字编号转换为字典
:return: 返回自编码字典
"""
token_dict = {}
with codecs.open(dict_path, 'r', 'utf8') as reader:
for line in reader:
token = line.strip()
token_dict[token] = len(token_dict)
return token_dict
def text2input(text_list):
tokenizer = OurTokenizer(get_token_dict())
X1, X2, Y = [], [], []
for text in text_list:
text = text[:maxlen]
x1, x2 = tokenizer.encode(first=text)
X1.append(x1)
X2.append(x2)
return [X1, X2]
if __name__ == '__main__':
textlist = ["芜湖起飞","大威天龙"]
x = text2input(textlist)
print(x)
inputs = {}
input_data = {
# "signature_name": 'serving_default',
# tfserving 有两种输入格式,一种是row/instances行格式,一种是column/inputs格式,本例使用第二种格式
# column/inputs格式是,inputs为dict{},每一个输入样本有两个输入input_1和input_2,将每一个样本的相同输入放在一个列表
# row/instances格式是,instances为list[],包含一个个instance,每一个instance是一个完整的样本输入,包含input_1和input_2
# 参考https://www.tensorflow.org/tfx/serving/api_rest#request_format_2
"inputs": {
"input_1": x[0], # input_1列表
"input_2": x[1], # input_2列表
},
}
headers = {"content-type": "application/json"}
data = json.dumps(input_data, indent=None)
url = "http://192.168.70.163:8501/v1/models/" + model_name + ":predict"
json_response = requests.post(url, data=data, headers=headers)
# print(json_response.content)
outputs = np.array(json.loads(json_response.text)['outputs'])
print(outputs)
y = [np.argmax(s) for s in outputs]
print(y)
标签:serving,name,部署,objects,dict,input,tensorflow,model,tfserving 来源: https://www.cnblogs.com/lxzbky/p/16348181.html