tensorflow_tflite专题
作者:互联网
tensorflow_tflite专题
本文章主要包括两大问题:
tflite的转换:如何转换得到tflite?
tflite的测试:如何测试或者说如何在PC端使用tflite?
问题一:如何转换得到tflite
分为两个过程,步骤:cheakpoint→pb模型→tflite模型
- step1:cheakpoint→tflite_graph.pb:
使用object_detection的export_tflite_ssd_graph.py,结果生成tflite_graph.pb和tflite_graph.pbtxt两个文件
超参数:
"output_directory":输出的文件夹
"pipeline_config_path":网络配置文件
"trained_checkpoint_prefix":你的cheakpoint文件
- step2:tflite_graph.pb→out_put.tflite:
使用convert.py程序讲pb转换为tflite,这里的pb是上一步转换得到了,不能是其他来源的pb模型
import tensorflow as tf
# 需要配置
in_path = "tflite_graph.pb"
# 模型输入节点对于object_detection是固定的,不需改动,但是shape是和网络有关
input_tensor_name = ["normalized_input_image_tensor"]
input_tensor_shape = {"normalized_input_image_tensor":[1,256,256,3]}
# 模型输出节点,对于object_detection是固定的,不需改动
classes_tensor_name = ['TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', 'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3']
converter = tf.lite.TFLiteConverter.from_frozen_graph(in_path,input_tensor_name, classes_tensor_name,input_tensor_shape)
converter.allow_custom_ops=True
converter.post_training_quantize = True
tflite_model = converter.convert()
open("output.tflite", "wb").write(tflite_model)
print("done")
问题二:如何测试或者说如何在PC端使用tflite?
这里给出代码:
import numpy as np
import tensorflow as tf
import cv2 #用来读取图片并进行预处理
import glob #读取某文件夹所有测试图片
import time #主要是用来计算推理花费时间
# Load TFLite model and allocate tensors.
model_path="output_fp16.tflite" #tflite路径
interpreter = tf.lite.Interpreter(model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details) #在这里可以看到tflite的输入输出的节点信息
def detection(img_src):
img = cv2.resize(img_src, (256, 256))
img = img / 128 - 1
input_data = np.expand_dims(img, 0)
input_data = input_data.astype(np.float32)
#以上是对图片经行尺寸变换、归一化、添加维度和类型转换,以便和输入节点对应
index = input_details[0]['index']
interpreter.set_tensor(index, input_data)
interpreter.invoke() #启动
output0 = interpreter.get_tensor(output_details[0]['index']) # bbox
output1 = interpreter.get_tensor(output_details[1]['index']) # bbox
output2 = interpreter.get_tensor(output_details[2]['index']) # bbox
output3 = interpreter.get_tensor(output_details[3]['index']) # 概率
#在这里你可以通过print查看4个输出的信息
#分别时object_detection的信息:
#对于我来讲,人脸检测不涉及类别,所以我只用到
# output0:位置信息
# output2:对应的概率
#我只要概率最大的人脸,且概率>0.6保持,否则讲概率置为0
output3=output3[0][0] if output3[0][0] > 0.6 else 0
return bbox,output3 #返回概率信息和其位置信息
imgs_path = glob.glob('../../test_iamge/*')
for img_path in imgs_path:
t1=time.time()
img=cv2.imread(img_path)
sp = img.shape
bbox,confidence=detection(img)
if confidence!=0:
print('置信度=',confidence,' bbox=',bbox,end=' ')
y1 = int(bbox[0][0][0] * sp[0])
x1 = int(bbox[0][0][1] * sp[1])
y2 = int(bbox[0][0][2] * sp[0])
x2 = int(bbox[0][0][3] * sp[1])
cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 3)
print('time=',time.time()-t1)
cv2.namedWindow(str(confidence*100)[2:6]+'%', 0)
cv2.imshow(str(confidence*100)[2:6]+'%', img)
cv2.waitKey(0)
cv2.destroyAllWindows()
else:
print('time=',time.time()-t1)
标签:专题,tensor,img,tflite,bbox,input,tensorflow,details 来源: https://www.cnblogs.com/thgpddl/p/13550417.html