网站首页 文章专栏 tensorflow保存与恢复模型
ckpt模型可以重新训练,pb模型不可以(pb一般用于线上部署)
ckpt模型可以指定保存最近的n个模型,pb不可以
保存路径必须带.ckpt这个后缀名,不能是文件夹,否则无法保存meta文件
”`python
CKPT_PATH = ‘./model.ckpt’
vgg16_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=‘vgg19’)
outputs_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=‘outputs’)
saver = tf.train.Saver(vgg16_variables + outputs_variables, max_to_keep=1)
saver.save(sess, CKPT_PATH)
”`
pyhon
ckpt = tf.train.get_checkpoint_state('ckpt')
if ckpt:
saver.restore(sess, ckpt.model_checkpoint_path)
print('Restore from', ckpt.model_checkpoint_path)
gstep = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
保存为pb模型时要指明对外暴露哪些接口
py
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['inputs','labels','keep_prob','accuracy']
)
with tf.gfile.GFile('save.pb', 'wb') as fid:
serialized_graph = output_graph_def.SerializeToString()
fid.write(serialized_graph)
pb 格式模型保存与恢复相比于前面的 .ckpt 格式而言要稍微麻烦一点,但使用更灵活,特别是模型恢复,因为它可以脱离会话(Session)而存在,便于部署。
加载步骤如下:
tf.Graph()
定义了一张新的计算图,与上面的计算图区分开
ParseFromString
将保存的计算图反序列化
tf.import_graph_def
导入一张计算图
Session
,获取Tensor
py
model_graph = tf.Graph()
with model_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile('save.pb', 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
with tf.Session(graph=model_graph) as sess:
inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
labels = tf.get_default_graph().get_tensor_by_name('labels:0')
keep_prob = tf.get_default_graph().get_tensor_by_name('keep_prob:0')
accuracy = tf.get_default_graph().get_tensor_by_name('accuracy:0')
batch_xs, batch_ys = mnist.test.next_batch(100)
batch_xs = batch_xs.reshape([-1, 28, 28, 1])
acc = sess.run(accuracy, feed_dict={inputs: batch_xs, labels: batch_ys, keep_prob:1.0})
print('After restore sess from pb file, accuracy is ', acc)