Tensorflow 获取model中的变量列表
作者:互联网
1、动态获取
(1)朴素获取法
1) 朴素获取可训练变量:t_vars = tf.trainable_variables()
2)朴素获取全部变量,包含声明training=False变量:all_vars = tf.global_variables()
(2)使用tensorflow.contrib.slim
1) 获取常规变量(是slim里面与model变量对应的一个类型):regular_variables = slim.get_variables()
2)直接获取:vars = slim.get_variables_to_restore()
3)slim用于筛选方法
a. 通过name筛选: variables = slim.get_variables_by_name("d_")
b. 通过name后缀筛选:variables = slim.get_variables_by_suffix("_b")
c. 通过namespace筛选:variables = slim.get_variables(scope="layer1")
d. 通过include和exclude筛选
d0. variables_to_restore = slim.get_variables_to_restore(include=["d_"])
d1. variables_to_restore = slim.get_variables_to_restore(exclude=["_w"])
(3) 离线获取(从一个已保存好的模型中获取var_list)
1) 离线文件: checkpoint、model.data-xxxx、model.index、model.meta
2) 将离线文件载入当前环境,变成动态获取
#记住,要先清空现有的图
#不然的话import_meta_graph会把原model里面的数据追加到现有的model中
#一片混乱
tf.reset_default_graph()
with tf.Session(graph=tf.get_default_graph()) as sess:
new_saver = tf.train.import_meta_graph('e:/mytrain/results/20190227_01/model/model.meta')
new_saver.restore(sess, 'e:/mytrain/results/20190227_01/model/model')
#加载进来之后还不是为所欲为
var_list=tf.global_variables()
3) 直接从离线文件获取
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
#文件夹地址改成自己的
model_dir="'e:\\20190227_01\\mytrain\\results\\20190227_01\\model"
ckpt = tf.train.get_checkpoint_state(model_dir)
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
#返回一个dict= {'name':[shape] }
#例如 'd_w2/Adam':[4, 4, 32, 64]
var_to_shape_map = reader.get_variable_to_shape_map()
#我们可以用遍历的方式,取出字典里所有的key
for key in var_to_shape_map:
print(key) #key是str类型的
#再用key去找这个tensor的值
a=reader.get_tensor(key)
print(type(a)) #输出: <class 'numpy.ndarray'>
标签:get,variables,slim,列表,获取,tf,Tensorflow,model 来源: https://blog.csdn.net/NOT_GUY/article/details/118417634