聊聊 TensorFlow 的 build_graph 问题
发现问题
前一阵子在调一个文本分类的模型,调试的时候发现一个奇特的报错,大概如下:
1 | ValueError: Variable layer1/weights1 already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at: |
问题原因
后面打断点 Debug 发现问题出在重复使用 TensorFlow 的 build_graph
方法上。那么为什么不能重复使用这个方法呢?
说到这个,可能要稍微深入地理解下 TF 的工作原理了。TF 内置了一个隐藏的状态,用于存储建立好的图。一旦调用了 build_graph
方法后,所有之前通过 TF 创建的变量和 Op
都会被用于构建图并存储起来。
实际的运算依赖之前建立的图,并发生成 tf.Session()
中。如果有面向对象的开发经验,可以这样理解:
build_graph
相当于把之前建图的操作生成一个类,而这个类一旦生成就封版的,不能修改tf.Session()
则是创建了运行环境,并实例化出一个图的对象,进行模型训练和预测
一般来说,我们使用 TF 建图和训练时用到的代码结构是这样的:
1 | build_graph() |
因为实际的变量、训练结果都局限在特定的「运行环境」中,所以我们也通过下面的代码,在两个不同的 Session
中以不同的参数训练:
1 | build_graph() |
那么为什么我会遇到多次 build_graph
这样的场景和报错呢?我之前写了一个框架,分离了模型的训练和评估代码,模型的训练和评估是可以正常单独运行的。
而在某种场合下,我需要在同一脚本中,依次调用模型的训练和评估,这个时候就出现问题了。大概代码执行顺序如下:
1 | build_graph() |
那么是不是完全无解了呢?当然也不是!可以在第二次 build_graph
之前,使用 TF 的 reset_default_graph
方法,重置已经建的图。
1 | build_graph() |
TF 官方文档中有提到该方法的使用误区:
Calling this function while a tf.Session or tf.InteractiveSession is active will result in undefined behavior. Using any previously created tf.Operation or tf.Tensor objects after calling this function will result in undefined behavior
所以,reset_default_graph
也不要乱用,比如像下面这种使用方法就是错的。
1 | import tensorflow as tf |
参考文献
Error while running tensorflow a second time
How to use tf.reset_default_graph()