一、模型格式

Tensorflow的保存分为三种:1. checkpoint模式;2. pb模式;3. saved_model模式。

Tensorflow中模型的保存一般使用tf.train.Saver()定义的存储器对象来保存魔心那个,并得到如下的文件:

checkpoint
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

# "model.ckpt" 是文件名
#保存时候生成三个文件,data对应的权重,meta是深度学习结构图,index对应的图和权重的索引,
# checkpoint是所有保存的model.ckpt的总括
# ckpt是一个模型快照

 二、意义

pb文件格式与语言无关,好使用,其次只有一个文件。空间也小。

三、ckpt转pb格式

原理就是根据输出节点(重点节点不是张量),保存所有子图。

tensorflow 1.15版本

with tf.Session(config=config) as sess:
    #加载ckpt格式模型
    model_file = tf.train.latest_checkpoint(FLAGS.model_path)
    saver = tf.train.import_meta_graph("{}.meta".format(model_file))
    saver.restore(sess, model_file)
    
    def frozen():
        graph = tf.get_default_graph()
        input_graph_def = graph.as_graph_def()

        frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, ["representation/qs_y_raw","representation/q_y_raw"])
  

       # 保存模型
        output_graph="pb_model/frozen_model.pb"
        with tf.gfile.GFile(output_graph, "wb") as f:
                f.write(frozen_graph_def.SerializeToString())  # 序列化输出
            # 得到当前图有几个操作节点
        print("%d ops in the final graph." % len(frozen_graph_def.node))

    #     for op in graph.get_operations():
    #             print(op.name,"op.values:",op.values())
        print("model have been frozen... ...")

 

 

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐