TensorFlow静态图pb(frozen graph)模型保存与调用
作者:互联网
pb模型保存
基于TensorFlow 2
# Local imports: the original snippet used these names without importing them.
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

model = ...  # placeholder: the tf.keras model to freeze
# Convert Keras model to ConcreteFunction.
# Traces the model once with a fixed input signature; only model.inputs[0]
# is used, so this covers single-input models only.
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
# Get frozen ConcreteFunction: variables are replaced by constants.
frozen_func = convert_variables_to_constants_v2(full_model)
# Save frozen graph from frozen ConcreteFunction to hard drive
# (as_text=False writes the binary protobuf, i.e. a .pb file).
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)
基于Keras (TensorFlow 1)
from tensorflow.keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the variables of ``session``'s graph into constants.

    Args:
        session: TF1-style session whose graph and current variable values
            are frozen.
        keep_var_names: names of global variables to exclude from freezing.
        output_names: names of the output ops passed to
            ``convert_variables_to_constants``; ``None`` is treated as ``[]``.
        clear_devices: if True, device placements are stripped from every node
            before freezing. If False, every node of the frozen graph is
            assigned to ``"/GPU:0"``.

    Returns:
        The frozen ``GraphDef`` produced by ``convert_variables_to_constants``.
    """
    # Function-local import: the original snippet referenced ``tf`` and
    # ``graph_util`` without importing them. ``compat.v1`` exposes the TF1 API
    # on both late TF1 and TF2 installs.
    import tensorflow.compat.v1 as tf_v1

    graph = session.graph
    with graph.as_default():
        # Freeze every global variable except those explicitly kept.
        freeze_var_names = list(
            set(v.op.name for v in tf_v1.global_variables()).difference(keep_var_names or [])
        )
        output_names = output_names or []
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf_v1.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names
        )
        if not clear_devices:
            # NOTE(review): this overwrites the original placements and pins
            # everything to GPU:0 -- surprising for a flag named
            # ``clear_devices``; confirm this is intended.
            for node in frozen_graph.node:
                node.device = "/GPU:0"
    return frozen_graph
# load model
# The original snippet used ``keras`` without importing it.
from tensorflow import keras

# Passed to freeze_session below; the original snippet referenced this name
# without ever defining it (NameError). True strips device placements.
clear_devices = True

# NOTE: model_from_json rebuilds the architecture only. Load the trained
# weights (e.g. model.load_weights(...)) before freezing; otherwise the
# freshly initialized weights are what gets frozen.
model = keras.models.model_from_json(...)
# save pb model
out_path = 'model.pb'
# Op names of the model's inputs/outputs. Looking tensors up after loading
# the .pb uses these names plus an output index, e.g. "<name>:0".
input_names = [n.op.name for n in model.inputs]
output_names = [n.op.name for n in model.outputs]
print(input_names, output_names)
frozen_graph = freeze_session(K.get_session(), output_names=output_names, clear_devices=clear_devices)
# Serialize the frozen GraphDef as a binary protobuf.
with open(out_path, "wb") as f:
    f.write(frozen_graph.SerializeToString())
模型调用
这里以TensorFlow 1为例（get_tensor_by_name中的张量名需与保存模型时打印的input_names/output_names一致，并加上输出序号，如":0"）:
from tensorflow.compat.v1 import Graph, GraphDef, import_graph_def, Session
from tensorflow.compat.v1.gfile import GFile
# The original snippet used ``np`` without importing it.
import numpy as np

frozen_graph = "model.pb"
# import graph: read the serialized GraphDef from disk.
with GFile(frozen_graph, "rb") as f:
    graph_def = GraphDef()
    graph_def.ParseFromString(f.read())
with Graph().as_default() as graph:
    # name="" imports nodes without a name-scope prefix, so the tensor names
    # below are the same as in the saved model.
    import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=""
                     )
# set input output
# Tensor name = op name + ":<output index>"; replace with your model's names.
x = graph.get_tensor_by_name("input:0")
y1 = graph.get_tensor_by_name("output1:0")
y2 = graph.get_tensor_by_name("output2:0")  # each output has its own tensor name
# get batch_input
batch_image = np.zeros([1, 512, 512, 3])
# get ...
# predict
feed_dict_testing = {x: batch_image}
# The context manager closes the session and releases its resources.
with Session(graph=graph) as sess:
    output1, output2 = sess.run([y1, y2], feed_dict=feed_dict_testing)
标签:frozen,graph,pb,names,input,model,def 来源: https://blog.csdn.net/dou3516/article/details/110871797