tensorflow 模型的存档、保存、冻结、优化

模型的保存分三种类型

  1. 知道模型结构,单纯保存变量
  2. 不知道模型结构,保存模型和变量
  3. 不需要再改变量,只要常量化的模型(“冻结”)

第一种用于训练的存档,并且临时恢复,这个时候用户是把训练需要的网络结构在代码里面构造好了的,只是在一定的时间下需要暂时保存网络中的变量,为了在崩溃之后继续训练。所以自然而然会有一个问题,如果我用 Python 写的代码,需要在 C++ 当中恢复,我需要知道你的模型结构,才能恢复,这个最蠢的办法是用 C++ 把你的网络结构再构造一遍,但我们按照统一的协议(比如 Protobuf)确定网络结构,就可以直接从标准序列化的数据中解析网络结构,这就是第二种情况,独立于语言,模型和变量一起保存的情况。然后如果碰到我们不需要再训练了,比如只是把这个模型进行部署,不需要改变相关的变量,那么其实只要一个带常量的模型就可以,这就是第三种情况,把变量冻结的正向传播模型。接下来会依次解释这几种情况的工作方式。

除了这些以外,针对用于服务的模型还可以做很多的优化。

存档

存档只是单纯的保存变量,并且能够恢复,可以在一定的迭代次数以后保存变量,并且从任意一个存档开始重新训练。以两个变量加减 1 为例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import tensorflow as tf
# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/tf-test/model.ckpt")
print("Model saved in path: %s" % save_path)

可以在 /tmp/tf-test 下面看到这几个文件 checkpoint model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta

可以通过脚本观察保存的变量 python $tensorflow-src/tensorflow/python/tools/inspect_checkpoint.py --file_name=/tmp/tf-test/model.ckpt --all_tensors

得到保存的变量的内容,注意 model.ckpt 这个只是文件前缀。

1
2
3
4
tensor_name: v1
[1. 1. 1.]
tensor_name: v2
[-1. -1. -1. -1. -1.]

如果要恢复的话,可以通过下面的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import tensorflow as tf
# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/tf-test/model.ckpt")
print("Model restored.")
# Check the values of the variables
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())

得到一样的效果

1
2
v1 : [1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]

具体来说 .meta 对应的是 MetaGraph 和 SaverGraph,.index 对应的是变量值的位置,key 是变量名,value 是变量保存的入口定义,data 变量的值具体保存的文件。这是恢复代码中已经原样构造出了 Graph,如果没有构造的化,需要通过 tf.train.import_meta_graph('/tmp/model.ckpt.meta') 来加载,但是存档保存的信息比较单一,Tensorflow 提供了一个更丰富的 API 来使用。

保存

SavedModelBuilder 保存的 API 比较丰富,能够保存多个 MetaGraph 和 Variables 的组合,除此之外还能附带 assets,并且要指定模型签名,simple_saved 的方法是一个简单版本的调用,适用于 Predict API。这里要展开一下 GraphDef, MetaGraphDef, SignatureDef, tags 这些东西的概念。对于 MetaGraph,这篇文章解释得很清楚。SignatureDef 是对应了一种图的输入和输出,可以依据这个进行 serving API 的调用,类似于函数签名,相对于一个接口的定义。

tensorflow_serving 自己给了个例子,执行 python mnist_saved_model.py /tmp/tf-test-2 以后可以获得一个目录,下面有版本 1 的模型数据,执行 saved_model_cli show --dir /tmp/tf-test-2/1 可以查看对应的签名。可以看到对应的层级关系,默认用于服务的模型会打上 serve 的标签,函数签名有两个,分别对应了 predict 和 classify 的方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['predict_images']:
The given SavedModel SignatureDef contains the following input(s):
inputs['images'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 784)
name: x:0
The given SavedModel SignatureDef contains the following output(s):
outputs['scores'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 10)
name: y:0
Method name is: tensorflow/serving/predict
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['inputs'] tensor_info:
dtype: DT_STRING
shape: unknown_rank
name: tf_example:0
The given SavedModel SignatureDef contains the following output(s):
outputs['classes'] tensor_info:
dtype: DT_STRING
shape: (-1, 10)
name: index_to_string_Lookup:0
outputs['scores'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 10)
name: TopKV2:0
Method name is: tensorflow/serving/classify

可以参考 tensorflow 的 REST API,比如 GET http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}] 其实对应这个例子就是 GET http://host:port/v1/models/tf-test-2/versions/1,然后感觉函数签名不同的 method name,可以调用不同的 request,比如 POST http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]:predict 这个格式,如果输入和输出对应的是 imagesscores 那么就对应了第一个签名。

冻结

冻结的情况就是变量不再需要修改,直接把变量转化成常量保存成单一的模型,方便在部署的场景下使用。
冻结模型的代码在这里,他的主要流程如下

  1. 清除所有 Op 中的 device,让原来在指定 CPU/GPU/节点 上的 Op 不再绑定。
  2. 通过 graph_util.convert_variables_to_constants 将所有的 Variable eval 一次,把变量的 Op 的结果拿到,替换成 constant

优化

除了冻结模型以外,还可以删减一些多余的节点,比如 Summary 节点或者 Identity 节点,甚者把 16bit 的浮点数权重修改为 8bit 的浮点数权重(这个在 Tensorflow Lite 里很有用)。这篇文章 列出了详细的优化方式,主要是靠 transform_graph 这个工具,地址在,他有很详细的柴剪列表,并且可以自己编写裁剪函数,充分做到模型在部署环节的“纯净化”,调用方式也很简单。

1
2
3
4
5
6
7
8
9
10
transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul:0' \
--outputs='softmax:0' \
--transforms='
strip_unused_nodes(type=float, shape="1,299,299,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_old_batch_norms
'

transforms里面加入你想进行优化的 transformer 和对应的参数即可,在科赛上也有在线可以跑的notebook

参考

  1. Cloud ML Engine
Donate comment here