tensorflow保存模型、恢复模型
作者:互联网
1、模型训练(部分代码):
X = tf.placeholder(tf.float64,X_data.shape,name='X') Y = tf.placeholder(tf.float64,Y_data.shape,name='Y') epoch_num = 500 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) loss_data = [] # 创建FileWriter对象,用当前计算图初始化 writer = tf.summary.FileWriter('./summary/', sess.graph) # 保存模型 saver_path = './model/checkpoint/model.ckpt' # 模型保存路径 saver = tf.train.Saver() # 新建Saver()对象 for i in range(1,epoch_num+1): _, loss = sess.run([optimizer,loss_func],feed_dict={X:X_data,Y:Y_data}) loss_data.append(loss) saved_path = saver.save(sess, saver_path) # 保存模型 print("epoch:%d,loss:%.4g" % (i,loss)) # 关闭FileWriter writer.close()
2、保存模型
# 模型保存路径 saver_path = './model/checkpoint/model.ckpt' # 新建Saver()对象 saver = tf.train.Saver() # 保存模型 saved_path = saver.save(sess, saver_path)
执行之后,在目录./model/checkpoint/model.ckpt下,生成模型相关文件,如图:
3、恢复模型并使用模型、变量
meta_path = './model/checkpoint/model.ckpt.meta' model_path = './model/checkpoint/model.ckpt' # 导入计算图 saver = tf.train.import_meta_graph(meta_path) config = tf.ConfigProto() with tf.Session(config=config) as sess: # 恢复模型 saver.restore(sess, model_path) # 此时默认图就是导入的图 graph_restore = tf.get_default_graph() # 恢复变量 W = graph_restore.get_tensor_by_name('W:0') b = graph_restore.get_tensor_by_name('b:0') # 预测模型 predict_func = tf.matmul(test_data, W) predict_value = sess.run([predict_func],feed_dict={x:test_data})
标签:sess,模型,保存,path,tf,tensorflow,model,data,saver 来源: https://blog.csdn.net/qq_32172681/article/details/96315505