模型的保存分三种类型
- 知道模型结构,单纯保存变量
- 不知道模型结构,保存模型和变量
- 不需要再改变量,只要常量化的模型(“冻结”)
第一种用于训练的存档,并且临时恢复,这个时候用户是把训练需要的网络结构在代码里面构造好了的,只是在一定的时间下需要暂时保存网络中的变量,为了在崩溃之后继续训练。所以自然而然会有一个问题,如果我用 Python 写的代码,需要在 C++ 当中恢复,我需要知道你的模型结构,才能恢复,这个最蠢的办法是用 C++ 把你的网络结构再构造一遍,但我们按照统一的协议(比如 Protobuf)确定网络结构,就可以直接从标准序列化的数据中解析网络结构,这就是第二种情况,独立于语言,模型和变量一起保存的情况。然后如果碰到我们不需要再训练了,比如只是把这个模型进行部署,不需要改变相关的变量,那么其实只要一个带常量的模型就可以,这就是第三种情况,把变量冻结的正向传播模型。接下来会依次解释这几种情况的工作方式。
除了这些以外,针对用于服务的模型还可以做很多的优化。
存档
存档只是单纯的保存变量,并且能够恢复,可以在一定的迭代次数以后保存变量,并且从任意一个存档开始重新训练。以两个变量加减 1 为例。
1 | import tensorflow as tf |
可以在 /tmp/tf-test
下面看到这几个文件 checkpoint model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta
。
可以通过脚本观察保存的变量 python $tensorflow-src/tensorflow/python/tools/inspect_checkpoint.py --file_name=/tmp/tf-test/model.ckpt --all_tensors
得到保存的变量的内容,注意 model.ckpt
这个只是文件前缀。
1 | tensor_name: v1 |
如果要恢复的话,可以通过下面的代码。
1 | import tensorflow as tf |
得到一样的效果
1 | v1 : [1. 1. 1.] |
具体来说 .meta 对应的是 MetaGraph 和 SaverGraph,.index 对应的是变量值的位置,key 是变量名,value 是变量保存的入口定义,data 变量的值具体保存的文件。这是恢复代码中已经原样构造出了 Graph,如果没有构造的化,需要通过 tf.train.import_meta_graph('/tmp/model.ckpt.meta')
来加载,但是存档保存的信息比较单一,Tensorflow 提供了一个更丰富的 API 来使用。
保存
SavedModelBuilder
保存的 API 比较丰富,能够保存多个 MetaGraph 和 Variables 的组合,除此之外还能附带 assets,并且要指定模型签名,simple_saved
的方法是一个简单版本的调用,适用于 Predict API。这里要展开一下 GraphDef, MetaGraphDef, SignatureDef, tags 这些东西的概念。对于 MetaGraph,这篇文章解释得很清楚。SignatureDef 是对应了一种图的输入和输出,可以依据这个进行 serving API 的调用,类似于函数签名,相对于一个接口的定义。
tensorflow_serving 自己给了个例子,执行 python mnist_saved_model.py /tmp/tf-test-2
以后可以获得一个目录,下面有版本 1 的模型数据,执行 saved_model_cli show --dir /tmp/tf-test-2/1
可以查看对应的签名。可以看到对应的层级关系,默认用于服务的模型会打上 serve 的标签,函数签名有两个,分别对应了 predict 和 classify 的方法。
1 | MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: |
可以参考 tensorflow 的 REST API,比如 GET http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]
其实对应这个例子就是 GET http://host:port/v1/models/tf-test-2/versions/1
,然后感觉函数签名不同的 method name,可以调用不同的 request,比如 POST http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]:predict
这个格式,如果输入和输出对应的是 images
和 scores
那么就对应了第一个签名。
冻结
冻结的情况就是变量不再需要修改,直接把变量转化成常量保存成单一的模型,方便在部署的场景下使用。
冻结模型的代码在这里,他的主要流程如下
- 清除所有 Op 中的 device,让原来在指定 CPU/GPU/节点 上的 Op 不再绑定。
- 通过
graph_util.convert_variables_to_constants
将所有的 Variableeval
一次,把变量的 Op 的结果拿到,替换成 constant
优化
除了冻结模型以外,还可以删减一些多余的节点,比如 Summary 节点或者 Identity 节点,甚者把 16bit 的浮点数权重修改为 8bit 的浮点数权重(这个在 Tensorflow Lite 里很有用)。这篇文章 列出了详细的优化方式,主要是靠 transform_graph
这个工具,地址在这,他有很详细的柴剪列表,并且可以自己编写裁剪函数,充分做到模型在部署环节的“纯净化”,调用方式也很简单。
1 | transform_graph \ |
在transforms
里面加入你想进行优化的 transformer 和对应的参数即可,在科赛上也有在线可以跑的notebook