聊聊 TensorFlow 的 build_graph 问题

发现问题

前一阵子在调一个文本分类的模型,调试的时候发现一个奇特的报错,大概如下:

1
ValueError: Variable layer1/weights1 already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:

问题原因

后面打断点 Debug 发现问题出在重复使用 TensorFlow 的 build_graph 方法上。那么为什么不能重复使用这个方法呢?

说到这个,可能要稍微深入地理解下 TF 的工作原理了。TF 内置了一个隐藏的状态,用于存储建立好的图。一旦调用了 build_graph 方法后,所有之前通过 TF 创建的变量和 Op 都会被用于构建图并存储起来。

实际的运算依赖之前建立的图,并发生成 tf.Session() 中。如果有面向对象的开发经验,可以这样理解:

  • build_graph 相当于把之前建图的操作生成一个类,而这个类一旦生成就封版的,不能修改
  • tf.Session() 则是创建了运行环境,并实例化出一个图的对象,进行模型训练和预测

一般来说,我们使用 TF 建图和训练时用到的代码结构是这样的:

1
2
3
4
build_graph()

with tf.Session() as sess:
process_something()

因为实际的变量、训练结果都局限在特定的「运行环境」中,所以我们也通过下面的代码,在两个不同的 Session 中以不同的参数训练:

1
2
3
4
5
6
7
build_graph()

with tf.Session() as sess:
process_something()

with tf.Session() as sess:
process_something()

那么为什么我会遇到多次 build_graph 这样的场景和报错呢?我之前写了一个框架,分离了模型的训练和评估代码,模型的训练和评估是可以正常单独运行的。

而在某种场合下,我需要在同一脚本中,依次调用模型的训练和评估,这个时候就出现问题了。大概代码执行顺序如下:

1
2
3
4
5
6
7
8
9
build_graph()

with tf.Session() as sess:
process_something()

build_graph()

with tf.Session() as sess:
process_something()

那么是不是完全无解了呢?当然也不是!可以在第二次 build_graph 之前,使用 TF 的 reset_default_graph 方法,重置已经建的图。

1
2
3
4
5
6
7
8
9
10
11
build_graph()

with tf.Session() as sess:
process_something()

tf.reset_default_graph()

build_graph()

with tf.Session() as sess:
process_something()

TF 官方文档中有提到该方法的使用误区:

Calling this function while a tf.Session or tf.InteractiveSession is active will result in undefined behavior. Using any previously created tf.Operation or tf.Tensor objects after calling this function will result in undefined behavior

所以,reset_default_graph 也不要乱用,比如像下面这种使用方法就是错的。

1
2
3
4
import tensorflow as tf
a = tf.constant(1)
with tf.Session() as sess:
tf.reset_default_graph()

参考文献

  • Error while running tensorflow a second time

  • How to use tf.reset_default_graph()