其他分享
首页 > 其他分享> > 利用tensorflow-serving部署kashgari模型

利用tensorflow-serving部署kashgari模型

作者:互联网

原文链接:https://www.cnblogs.com/jclian91/p/11526547.html

本项目的data来自之前笔者标注的时间数据集,即标注出文本中的时间,采用BIO标注系统。chinese_wwm_ext文件夹为哈工大的预训练模型文件。
  model_train.py为模型训练的代码,主要功能是完成时间序列标注模型的训练,完整的代码如下:

-- coding: utf-8 --

time: 2019-09-12

place: Huangcun Beijing

import kashgari
from kashgari import utils
from kashgari.corpus import DataReader
from kashgari.embeddings import BERTEmbedding
from kashgari.tasks.labeling import BiLSTM_CRF_Model

模型训练

train_x, train_y = DataReader().read_conll_format_file(’./data/time.train’)
valid_x, valid_y = DataReader().read_conll_format_file(’./data/time.dev’)
test_x, test_y = DataReader().read_conll_format_file(’./data/time.test’)

bert_embedding = BERTEmbedding(‘chinese_wwm_ext_L-12_H-768_A-12’,
task=kashgari.LABELING,
sequence_length=128)

model = BiLSTM_CRF_Model(bert_embedding)

model.fit(train_x, train_y, valid_x, valid_y, batch_size=16, epochs=1)

Save model

utils.convert_to_saved_model(model,
model_path=‘saved_model/time_entity’,
version=1)
  运行该代码,模型训练完后会生成saved_model文件夹,里面含有模型训练好后的文件,方便我们利用tensorflow/serving进行部署。接着我们利用tensorflow/serving来完成模型的部署,命令如下:

docker run -t --rm -p 8501:8501 -v “/Users/jclian/PycharmProjects/kashgari_tf_serving/saved_model:/models/” -e MODEL_NAME=time_entity tensorflow/serving
其中需要注意该模型所在的路径,路径需要写完整路径,以及模型的名称(MODEL_NAME),这在训练代码(train.py)中已经给出(saved_model/time_entity)。

接着我们使用tornado来搭建HTTP服务,帮助我们方便地进行模型预测,runServer.py的完整代码如下:

-- coding: utf-8 --

import requests
from kashgari import utils
import numpy as np
from model_predict import get_predict

import json
import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options
import traceback

tornado高并发

import tornado.web
import tornado.gen
import tornado.concurrent
from concurrent.futures import ThreadPoolExecutor

定义端口为12333

define(“port”, default=16016, help=“run on the given port”, type=int)

模型预测

class ModelPredictHandler(tornado.web.RequestHandler):
executor = ThreadPoolExecutor(max_workers=5)

# get 函数
@tornado.gen.coroutine
def get(self):
    origin_text = self.get_argument('text')
    result = yield self.function(origin_text)
    self.write(json.dumps(result, ensure_ascii=False))

@tornado.concurrent.run_on_executor
def function(self, text):
    try:
        text = text.replace(' ', '')
        x = [_ for _ in text]

        # Pre-processor data
        processor = utils.load_processor(model_path='saved_model/time_entity/1')
        tensor = processor.process_x_dataset([x])

        # only for bert Embedding
        tensor = [{
            "Input-Token:0": i.tolist(),
            "Input-Segment:0": np.zeros(i.shape).tolist()
        } for i in tensor]

        # predict
        r = requests.post("http://localhost:8501/v1/models/time_entity:predict", json={"instances": tensor})
        preds = r.json()['predictions']

        # Convert result back to labels
        labels = processor.reverse_numerize_label_sequences(np.array(preds).argmax(-1))

        entities = get_predict('TIME', text, labels[0])

        return entities

    except Exception:
        self.write(traceback.format_exc().replace('\n', '<br>'))

get请求

class HelloHandler(tornado.web.RequestHandler):
def get(self):
self.write(‘Hello from lmj from Daxing Beijing!’)

主函数

def main():
# 开启tornado服务
tornado.options.parse_command_line()
# 定义app
app = tornado.web.Application(
handlers=[(r’/model_predict’, ModelPredictHandler),
(r’/hello’, HelloHandler),
], #网页路径控制
)
http_server = tornado.httpserver.HTTPServer(app)
http_server.listen(options.port)
tornado.ioloop.IOLoop.instance().start()

main()
  我们定义了tornado封装HTTP服务来进行模型预测,运行该脚本,启动模型预测的HTTP服务。接着我们再使用Python脚本才测试下模型的预测效果以及预测时间,预测的代码脚本的完整代码如下:

import time
import json
import requests

t1 = time.time()
texts = [‘据《新闻联播》报道,9月9日至11日,中央纪委书记赵乐际到河北调研。’,
‘记者从国家发展改革委、商务部相关方面获悉,日前美方已决定对拟于10月1日实施的中国输美商品加征关税措施做出调整,中方支持相关企业从即日起按照市场化原则和WTO规则,自美采购一定数量大豆、猪肉等农产品,国务院关税税则委员会将对上述采购予以加征关税排除。’,
‘据印度Zee新闻网站12日报道,亚洲新闻国际通讯社援引印度军方消息人士的话说,9月11日的对峙事件发生在靠近班公错北岸的实际控制线一带。’,
‘儋州市决定,从9月开始,对城市低保、农村低保、特困供养人员、优抚对象、领取失业保险金人员、建档立卡未脱贫人口等低收入群体共3万多人,发放猪肉价格补贴,每人每月发放不低于100元补贴,以后发放标准,将根据猪肉价波动情况进行动态调整。’,
‘9月11日,华为心声社区发布美国经济学家托马斯.弗里德曼在《纽约时报》上的专栏内容,弗里德曼透露,在与华为创始人任正非最近一次采访中,任正非表示华为愿意与美国司法部展开话题不设限的讨论。’,
‘造血干细胞移植治疗白血病技术已日益成熟,然而,通过该方法同时治愈艾滋病目前还是一道全球尚在攻克的难题。’,
‘英国航空事故调查局(AAIB)近日披露,今年2月6日一趟由德国法兰克福飞往墨西哥坎昆的航班上,因飞行员打翻咖啡使操作面板冒烟,导致飞机折返迫降爱尔兰。’,
‘当地时间周四(9月12日),印度尼西亚财政部长英卓华(Sri Mulyani Indrawati)明确表示:特朗普的推特是风险之一。’,
‘华中科技大学9月12日通过其官方网站发布通报称,9月2日,我校一硕士研究生不幸坠楼身亡。’,
‘微博用户@ooooviki 9月12日下午公布发生在自己身上的惊悚遭遇:一个自称网警、名叫郑洋的人利用职务之便,查到她的完备的个人信息,包括但不限于身份证号、家庭地址、电话号码、户籍变动情况等,要求她做他女朋友。’,
‘今天,贵阳取消了汽车限购,成为目前全国实行限购政策的9个省市中,首个取消限购的城市。’,
‘据悉,与全球同步,中国区此次将于9月13日于iPhone官方渠道和京东正式开启预售,京东成Apple中国区唯一官方授权预售渠道。’,
‘根据央行公布的数据,截至2019年6月末,存款类金融机构住户部门短期消费贷款规模为9.11万亿元,2019年上半年该项净增3293.19亿元,上半年增量看起来并不乐观。’,
‘9月11日,一段拍摄浙江万里学院学生食堂的视频走红网络,视频显示该学校食堂不仅在用餐区域设置了可以看电影、比赛的大屏幕,还推出了“一人食”餐位。’,
‘当日,在北京举行的2019年国际篮联篮球世界杯半决赛中,西班牙队对阵澳大利亚队。’,
]

print(len(texts))

for text in texts:
url = ‘http://localhost:16016/model_predict?text=%s’ % text
req = requests.get(url)
print(json.loads(req.content))

t2 = time.time()

print(round(t2-t1, 4))
  运行该代码,输出的结果如下:(预测文本中的时间)

一共预测15个句子。
[‘9月9日至11日’]
[‘日前’, ‘10月1日’, ‘即日’]
[‘12日’, ‘9月11日’]
[‘9月’]
[‘9月11日’]
[]
[‘近日’, ‘今年2月6日’]
[‘当地时间周四(9月12日)’]
[‘9月12日’, ‘9月2日’]
[‘9月12日下午’]
[‘今天’, ‘目前’]
[‘9月13日’]
[‘2019年6月末’, ‘2019年上半年’, ‘上半年’]
[‘9月11日’]
[‘当日’, ‘2019年’]
预测耗时: 15.1085s.
模型预测的效果还是不错的,但平均每句话的预测时间为1秒多,模型预测时间还是稍微偏长,后续笔者将会研究如何缩短模型预测的时间。
东莞网站建设www.zg886.cn

标签:serving,tornado,text,time,kashgari,12,import,tensorflow,model
来源: https://blog.csdn.net/ting2909/article/details/100886934