网站首页 文章专栏 tensorflow保存与恢复模型
tensorflow保存与恢复模型
创建于:2019-06-18 03:05:30 更新于:2024-11-22 10:43:58 羽瀚尘 779
未分类


ckpt模型与pb模型比较

  • ckpt模型可以重新训练,pb模型不可以(pb一般用于线上部署)

  • ckpt模型可以指定保存最近的n个模型,pb不可以

    保存ckpt模型


    保存路径必须带.ckpt这个后缀名,不能是文件夹,否则无法保存meta文件
    ”`python
    CKPT_PATH = ‘./model.ckpt’
    vgg16_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=‘vgg19’)
    outputs_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=‘outputs’)

    max_to_keep是指在文件夹中保存几个最近的模型

    saver = tf.train.Saver(vgg16_variables + outputs_variables, max_to_keep=1)

    saver.save(sess, CKPT_PATH)
    ”`

    恢复ckpt模型

    pyhon ckpt = tf.train.get_checkpoint_state('ckpt') if ckpt: saver.restore(sess, ckpt.model_checkpoint_path) print('Restore from', ckpt.model_checkpoint_path) gstep = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

    保存pb模型


    保存为pb模型时要指明对外暴露哪些接口
    py graph_def = tf.get_default_graph().as_graph_def() output_graph_def = graph_util.convert_variables_to_constants( sess, graph_def, ['inputs','labels','keep_prob','accuracy'] ) with tf.gfile.GFile('save.pb', 'wb') as fid: serialized_graph = output_graph_def.SerializeToString() fid.write(serialized_graph)

    加载pb模型

    pb 格式模型保存与恢复相比于前面的 .ckpt 格式而言要稍微麻烦一点,但使用更灵活,特别是模型恢复,因为它可以脱离会话(Session)而存在,便于部署。

    加载步骤如下:

  1. tf.Graph()定义了一张新的计算图,与上面的计算图区分开
  2. ParseFromString将保存的计算图反序列化
  3. tf.import_graph_def导入一张计算图
  4. 新建Session,获取Tensor
  5. 使用模型进行预测

    py model_graph = tf.Graph() with model_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile('save.pb', 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') with tf.Session(graph=model_graph) as sess: inputs = tf.get_default_graph().get_tensor_by_name('inputs:0') labels = tf.get_default_graph().get_tensor_by_name('labels:0') keep_prob = tf.get_default_graph().get_tensor_by_name('keep_prob:0') accuracy = tf.get_default_graph().get_tensor_by_name('accuracy:0') batch_xs, batch_ys = mnist.test.next_batch(100) batch_xs = batch_xs.reshape([-1, 28, 28, 1]) acc = sess.run(accuracy, feed_dict={inputs: batch_xs, labels: batch_ys, keep_prob:1.0}) print('After restore sess from pb file, accuracy is ', acc)