网上有很多关于 TensorFlow Lite 在安卓端部署的教程，但大多只讲如何把训练好的模型部署到安卓端，不讲如何训练。而实际上在部署的时候，需要知道训练模型时预处理的细节，这就导致自己训练的模型在部署到安卓端时出现各种问题。因此，本文会记录从 PC 端训练、导出到安卓端部署的各种细节，欢迎大家讨论、指教。
训练框架：TensorFlow Slim。关于 TensorFlow Slim 如何安装，这里不再赘述，大家自行百度解决。
# -*- coding: utf-8 -*-
"""Generate a dataset of 28x28 single-digit captcha images.

Each image is written as ``image<i>_<label>.jpg`` so that the training script
can recover the label from the file name.
"""
import os

import cv2
import numpy as np
from captcha.image import ImageCaptcha


def generate_captcha(text='1'):
    """Generate a digit image.

    Args:
        text: The string to draw (a single digit in this project).

    Returns:
        A uint8 numpy array converted from the PIL image produced by
        ``ImageCaptcha.generate_image`` (RGB channel order, as PIL uses).
    """
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(output_dir='./datasets/images/', num_images=50000):
    """Write ``num_images`` random digit captchas into ``output_dir``.

    Args:
        output_dir: Directory for the generated images (created if missing).
        num_images: How many images to generate.

    Raises:
        IOError: If an image cannot be written.
    """
    # cv2.imwrite does not create directories; it just returns False.
    os.makedirs(output_dir, exist_ok=True)
    for i in range(num_images):
        label = np.random.randint(0, 10)
        image = generate_captcha(str(label))
        image_name = 'image{}_{}.jpg'.format(i + 1, label)
        output_path = os.path.join(output_dir, image_name)
        # PIL yields RGB but cv2.imwrite expects BGR. Convert so the file on
        # disk has the true colours; train.py reads it back with
        # cv2.imread + COLOR_BGR2RGB and therefore trains on real RGB.
        bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(output_path, bgr_image):
            raise IOError('Failed to write {}'.format(output_path))


if __name__ == '__main__':
    main()
训练：本次训练我用 TensorFlow Slim 搭建了一个七层卷积的网络，最后测试准确率在 96%~99% 左右，模型大小约 1.2 MB，适合在移动端部署。训练的时候我做了两点工作：（1）给输入、输出节点显式命名（inputs 和 prob）；（2）训练完毕后直接把模型保存为 PB 格式。
# Give the input placeholder and the output tensor explicit names ('inputs'
# and 'prob') so they can be referenced later when freezing the graph,
# converting it to TFLite and running it on Android.
inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
.......
.......
prob_ = tf.identity(prob, name='prob')
# After training, turn the variables into constants. Only the listed nodes and
# their ancestors are kept, so keep both outputs: 'classes' is fetched by the
# frozen-graph inference script and 'prob' is the TFLite / Android output.
constant_graph = graph_util.convert_variables_to_constants(
    sess, sess.graph_def, ['inputs', 'classes', 'prob'])
# Save the frozen model directly in PB format (file name: model3.pb).
with tf.gfile.GFile('model3.pb', mode='wb') as f:
    f.write(constant_graph.SerializeToString())
# -*- coding: utf-8 -*-
"""Train a CNN model to classifying 10 digits.

Example Usage:
---------------
python3 train.py \
    --images_path: Path to the training images (directory).
    --model_output_path: Path to model.ckpt.
    --frozen_graph_path: Path of the frozen PB model (default: model3.pb).
"""

import cv2
import glob
import numpy as np
import os
import tensorflow as tf

import model

from tensorflow.python.framework import graph_util

flags = tf.app.flags

flags.DEFINE_string('images_path', None, 'Path to training images.')
flags.DEFINE_string('model_output_path', None, 'Path to model checkpoint.')
flags.DEFINE_string('frozen_graph_path', 'model3.pb',
                    'Path of the frozen graph (.pb) written after training.')
FLAGS = flags.FLAGS

# Nodes kept when the graph is frozen. 'classes' and 'prob' are both fetched by
# the frozen-graph inference script; 'prob' is also the TFLite model output.
OUTPUT_NODE_NAMES = ['inputs', 'classes', 'prob']


def get_train_data(images_path):
    """Get the training images from images_path.

    Each file is expected to be named ``imagexxx_label.jpg``; the label is the
    part of the file name after the last underscore.

    Args:
        images_path: Path to the directory with the training images.

    Returns:
        images: An array of RGB images (assumes all images share one size).
        labels: A 1-D integer array with the class of each image.

    Raises:
        ValueError: If images_path is missing/does not exist, or contains no
            readable .jpg images.
    """
    if images_path is None or not os.path.exists(images_path):
        raise ValueError('images_path is not exist.')
    images = []
    labels = []
    pattern = os.path.join(images_path, '*.jpg')
    for count, image_file in enumerate(glob.glob(pattern), start=1):
        if count % 100 == 0:
            print('Load {} images.'.format(count))
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread returns None instead of raising on unreadable files.
            print('Skip unreadable image: {}'.format(image_file))
            continue
        # OpenCV decodes to BGR; the network is trained on RGB, so every
        # inference path (PC scripts, Android) must feed RGB as well.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Assume the name of each image is imagexxx_label.jpg
        file_name = os.path.basename(image_file)
        label = int(file_name.split('_')[-1].split('.')[0])
        images.append(image)
        labels.append(label)
    if not images:
        raise ValueError('No readable .jpg images in {}.'.format(images_path))
    return np.array(images), np.array(labels)


def next_batch_set(images, labels, batch_size=128):
    """Sample a random training batch (indices drawn with replacement).

    Args:
        images: A 4-D array representing the training images.
        labels: A 1-D array representing the classes of images.
        batch_size: An integer.

    Returns:
        batch_images: A batch of images.
        batch_labels: A batch of labels.
    """
    indices = np.random.choice(len(images), batch_size)
    return images[indices], labels[indices]


def main(_):
    # Fail fast: the checkpoint path is otherwise only used after training.
    if FLAGS.model_output_path is None:
        raise ValueError('--model_output_path is required.')

    inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
    labels = tf.placeholder(tf.int32, shape=[None], name='labels')

    cls_model = model.Model(is_training=True, num_classes=10)
    # Preprocessing is part of the graph, so the exported model consumes the
    # same raw RGB pixel values that are fed to 'inputs' here.
    preprocessed_inputs = cls_model.preprocess(inputs)
    prediction_dict = cls_model.predict(preprocessed_inputs)
    loss_dict = cls_model.loss(prediction_dict, labels)
    loss = loss_dict['loss']
    postprocessed_dict = cls_model.postprocess(prediction_dict)
    classes = postprocessed_dict['classes']
    prob = postprocessed_dict['prob']
    # Stable node names used when freezing and at inference time.
    classes_ = tf.identity(classes, name='classes')
    tf.identity(prob, name='prob')
    acc = tf.reduce_mean(tf.cast(tf.equal(classes, labels), 'float'))

    global_step = tf.Variable(0, trainable=False)
    # Start at 0.05 and decay by a factor of 0.9 every 150 steps.
    learning_rate = tf.train.exponential_decay(0.05, global_step, 150, 0.9)
    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
    train_step = optimizer.minimize(loss, global_step)

    saver = tf.train.Saver()
    images, targets = get_train_data(FLAGS.images_path)

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        for i in range(6000):
            batch_images, batch_labels = next_batch_set(images, targets)
            train_dict = {inputs: batch_images, labels: batch_labels}
            # Update the weights and fetch the metrics in a single run rather
            # than paying for a second forward pass every step.
            _, loss_, acc_, classes__ = sess.run(
                [train_step, loss, acc, classes_], feed_dict=train_dict)
            train_text = 'step: {}, loss: {}, acc: {},classes:{}'.format(
                i + 1, loss_, acc_, classes__)
            print(train_text)

        saver.save(sess, FLAGS.model_output_path)

        # After training, freeze the variables and save the model in PB format.
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, OUTPUT_NODE_NAMES)
        with tf.gfile.GFile(FLAGS.frozen_graph_path, mode='wb') as f:
            f.write(constant_graph.SerializeToString())


if __name__ == '__main__':
    tf.app.run()
import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(input_checkpoint, output_graph,
                 output_node_names='inputs,classes,prob'):
    '''Convert a checkpoint into a frozen PB graph.

    :param input_checkpoint: checkpoint prefix as passed to ``saver.save``
        (e.g. 'model/model.ckpt'); '<prefix>.meta' is read to rebuild the graph.
    :param output_graph: path where the frozen PB model is written.
    :param output_node_names: names of the nodes to keep, either a
        comma-separated string or a list. Every name must exist in the original
        graph. The default keeps the input and both outputs ('classes' and
        'prob') that the inference and TFLite conversion steps use.
    :return: None
    '''
    if isinstance(output_node_names, str):
        output_node_names = output_node_names.split(',')
    node_names = [name.strip() for name in output_node_names if name.strip()]

    # Build in a fresh graph so repeated calls do not accumulate nodes in the
    # process-wide default graph.
    with tf.Graph().as_default() as graph:
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                           clear_devices=True)
        input_graph_def = graph.as_graph_def()
        with tf.Session(graph=graph) as sess:
            # Restore the trained variable values from the checkpoint.
            saver.restore(sess, input_checkpoint)
            # Replace variables with constants; only the listed nodes and their
            # ancestors survive in the output graph.
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                output_node_names=node_names)

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    # Number of nodes in the frozen graph.
    print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == '__main__':
    # Checkpoint prefix written by the training script.
    input_checkpoint = 'model/model.ckpt'
    # Output path of the frozen PB model.
    out_pb_path = "frozen_model.pb"
    freeze_graph(input_checkpoint, out_pb_path)
import tensorflow as tf

# Path to the frozen PB model; change it to your own file. The training script
# writes model3.pb, which matches the model3.tflite output below.
path = "model3.pb"
# If you don't know the model's input/output node names, inspect the graph in
# TensorBoard; the node names are shown there.
inputs = ["inputs"]
outputs = ["prob"]
# Convert the frozen PB model to a TFLite model.
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, outputs)
# Optional post-training quantization (smaller model); disabled here.
# converter.post_training_quantize = True
tflite_model = converter.convert()
# Use a context manager so the file is flushed and closed deterministically.
with open("model3.tflite", "wb") as f:
    f.write(tflite_model)
# Build TensorFlow's freeze_graph command-line tool (run inside a TensorFlow
# source tree, since the target is a package path in that workspace).
bazel build tensorflow/python/tools:freeze_graph
# Build TOCO, the TensorFlow Lite converter command-line tool.
bazel build tensorflow/lite/toco:toco
# -*- coding: utf-8 -*-
"""Evaluate the trained CNN model.

Example Usage:
---------------
python3 infrence_pb.py \
    --frozen_graph_path: Path to model frozen graph.

The frozen graph must contain the 'inputs', 'classes' and 'prob' nodes.
"""

import numpy as np
import tensorflow as tf

from captcha.image import ImageCaptcha

flags = tf.app.flags

flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    """Render ``text`` as a 28x28 captcha image (same settings as training data).

    Returns:
        A uint8 numpy array converted from the PIL image.
    """
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(_):
    if FLAGS.frozen_graph_path is None:
        raise ValueError('--frozen_graph_path is required.')

    # Load the serialized frozen graph into its own Graph object.
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')

    inputs = model_graph.get_tensor_by_name('inputs:0')
    classes = model_graph.get_tensor_by_name('classes:0')
    prob = model_graph.get_tensor_by_name('prob:0')

    with tf.Session(graph=model_graph) as sess:
        for _ in range(10):
            label = np.random.randint(0, 10)
            image = generate_captcha(str(label))
            # Add the batch dimension expected by the [None, 28, 28, 3] input.
            image_np = np.expand_dims(image, axis=0)
            predicted_label, probs = sess.run(
                [classes, prob], feed_dict={inputs: image_np})
            print(predicted_label, ' vs ', label)
            print(probs)


if __name__ == '__main__':
    tf.app.run()
# -*- coding:utf-8 -*-
"""Run the TFLite model on every image in test_image_dir and time it."""
import os
import cv2
import numpy as np
import time
import tensorflow as tf

test_image_dir = './test_images/'
# model_path = "./model/quantize_frozen_graph.tflite"
model_path = "./model3.tflite"

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
print(str(input_details))
output_details = interpreter.get_output_details()
print(str(output_details))

file_list = os.listdir(test_image_dir)
model_interpreter_time = 0
start_time = time.time()
# Iterate over the files in the test directory.
for file_name in file_list:
    print('=========================')
    full_path = os.path.join(test_image_dir, file_name)
    print('full_path:{}'.format(full_path))
    img = cv2.imread(full_path)
    if img is None:
        # cv2.imread returns None for files it cannot decode.
        print('skip non-image file:{}'.format(full_path))
        continue
    # cv2 decodes to BGR, but training converted images to RGB
    # (COLOR_BGR2RGB) before feeding them, so do the same here.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    res_img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
    # Add the batch dimension: (28, 28, 3) -> (1, 28, 28, 3). The model takes
    # raw 0-255 pixel values as float32 (preprocessing is inside the graph).
    image_np_expanded = np.expand_dims(res_img, axis=0).astype('float32')
    # Feed the input and run the model.
    model_interpreter_start_time = time.time()
    interpreter.set_tensor(input_details[0]['index'], image_np_expanded)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    model_interpreter_time += time.time() - model_interpreter_start_time
    # Drop the batch dimension: the result has length 10 (digits 0-9), and
    # the index of the maximum value is the predicted digit.
    result = np.squeeze(output_data)
    print('result:{}'.format(result))
    print('predicted:{}'.format(int(np.argmax(result))))

used_time = time.time() - start_time
print('used_time:{}'.format(used_time))
print('model_interpreter_time:{}'.format(model_interpreter_time))
// NOTE(review): {1, 3, 28, 28} is an NCHW-style shape (this field looks like a
// leftover from the Paddle demo this app is based on). The TFLite model's
// 'inputs' placeholder is [N, 28, 28, 3] (NHWC) — confirm whether ddims is
// still used anywhere.
private int[] ddims = {1, 3, 28, 28};
private static final String[] PADDLE_MODEL = {
// Read the class labels from the app's assets file cacheLabel1.txt.
// NOTE(review): make sure this reader is closed after use (e.g. try-with-resources).
BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel1.txt")));
// Output buffer for the model's 'prob' tensor: batch of 1, one score per
// digit class (the model was trained with num_classes=10).
float[][] labelProbArray = new float[1][10];
// Write one pixel into the input buffer as three floats in R, G, B order,
// assuming val is a packed ARGB int (e.g. from Bitmap.getPixels()).
// The normalized variants ((x - 128) / 128) are deliberately disabled: during
// training the 'inputs' placeholder was fed raw 0-255 RGB values and
// preprocessing happens inside the graph, so the app must feed raw values too.
// imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
// imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
// imgData.putFloat((((val & 0xFF) - 128f) / 128f));
imgData.putFloat(((val >> 16) & 0xFF) );
imgData.putFloat(((val >> 8) & 0xFF) );
imgData.putFloat((val & 0xFF) );
留一张测试图片，大家可以拿去测试，正确结果应该是 0.0。安卓代码地址是这里，CSDN 下载请搜索 anquangan。
#coding:utf-8
"""Print the nodes/operations of a frozen PB model (to find I/O node names)."""
import tensorflow as tf
from tensorflow.python.framework import graph_util

tf.reset_default_graph()  # Start from an empty default graph.
output_graph_path = 'model3.pb'

# Parse the serialized frozen model into a GraphDef.
output_graph_def = tf.GraphDef()
with open(output_graph_path, "rb") as f:
    output_graph_def.ParseFromString(f.read())

# Import it into the default graph. No session or variable initializer is
# needed: a frozen graph holds constants only, and nothing is executed here.
graph = tf.get_default_graph()
_ = tf.import_graph_def(output_graph_def, name="")

# Number of nodes in the graph.
print("%d ops in the final graph." % len(output_graph_def.node))
node_names = [node.name for node in output_graph_def.node]
print(node_names)
print('---------------------------')
# To visualize the model in TensorBoard, write a log for it, e.g.:
# summaryWriter = tf.summary.FileWriter('log_graph/', graph)
for op in graph.get_operations():
    # Print each operation's name and its output tensors.
    print(op.name, op.values())
