TensorFlow中遇到的问题及解决方法
本文为转载,记录一下在使用TensorFlow的过程中,遇到的较为困扰的问题及最终的解决方法。
Q1. 如何查看TensorFlow中Tensor, Variable, Constant的值?
TensorFlow中的许多方法返回的都是一个Tensor对象。在Debug的过程中,我们发现只能看到Tensor对象的一些属性信息,无法查看Tensor具体的输出值;而对于Variable和Constant,我们很容易对其进行创建操作,但是如何得到它们的值呢?
假设ts
是我们想要查看的对象(Variable / Constant / 0输入的Tensor),运行 1
2ts_res = sess.run(ts)
print(ts_res)sess
为之前创建或默认的session
. 运行后将得到一个narray
格式的ts_res
对象,通过print
函数我们可以很方便的查看其中的内容。
但是,如果ts
是一个有输入要求的Tensor,需要在查看其输出值前,填充(feed)输入数据。如下(假设ts只有一种输入): 1
2
3input = ×××××× # the input data need to feed
ts_res = sess.run(ts, feed_dict=input)
print(ts_res)Tensor
类似处理即可。
Q2. 模型训练完成后,如何获取模型的参数?
模型训练完成后,通常会将模型参数存储于/checkpoint/×××.model
文件(当然文件路径和文件名都可以更改,许多基于TensorFlow的开源包习惯将模型参数存储为model或者model.ckpt文件)。那么,在模型训练完成后,如何得到这些模型参数呢?
需要以下两个步骤:
Step 1: 通过tf.train.Saver()恢复模型参数
运行 1
saver = tf.train.Saver()
saver
的restore()
方法可以从本地的模型文件中恢复模型参数。大致做法如下: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24# your model's params
# you don't have to initialize them
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
W = tf.Variable(...)
b = tf.Variable(...)
y_ = tf.add(b, tf.matmul(x, w))
# create the saver
saver = tf.train.Saver()
# creat the session you used in the training processing
# launch the model
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/your/path/model.ckpt")
print("Model restored.")
# Do some work with the model, such as do a prediction
pred = sess.run(y_, feed_dict={batch_x})
...
Step 2: 通过tf.trainable\_variables()
得到训练参数
tf.trainable\_variables()
方法将返回模型中所有可训练的参数,详细见API文档。类似于以下的变量参数不会被返回: 1
tf_var = tf.Variable(0, name="××××××", trainable=False)
Variable
的name
属性过滤出需要查看的参数,如下: 1
var = [v for v in t_vars if v.name == "W"]