TensorFlow中遇到的问题及解决方法

本文为转载,记录一下在使用TensorFlow的过程中,遇到的较为困扰的问题及最终的解决方法。

Q1. 如何查看TensorFlow中Tensor, Variable, Constant的值?

TensorFlow中的许多方法返回的都是一个Tensor对象。在Debug的过程中,我们发现只能看到Tensor对象的一些属性信息,无法查看Tensor具体的输出值;而对于Variable和Constant,我们很容易对其进行创建操作,但是如何得到它们的值呢?

假设ts是我们想要查看的对象(Variable / Constant / 0输入的Tensor),运行

1
2
ts_res = sess.run(ts)
print(ts_res)
其中,sess为之前创建或默认的session. 运行后将得到一个narray格式的ts_res对象,通过print函数我们可以很方便的查看其中的内容。

但是,如果ts是一个有输入要求的Tensor,需要在查看其输出值前,填充(feed)输入数据。如下(假设ts只有一种输入):

1
2
3
input = ××××××  # the input data need to feed
ts_res = sess.run(ts, feed_dict=input)
print(ts_res)
其他要求多种输入的Tensor类似处理即可。

Q2. 模型训练完成后,如何获取模型的参数?

模型训练完成后,通常会将模型参数存储于/checkpoint/×××.model文件(当然文件路径和文件名都可以更改,许多基于TensorFlow的开源包习惯将模型参数存储为model或者model.ckpt文件)。那么,在模型训练完成后,如何得到这些模型参数呢?

需要以下两个步骤:

Step 1: 通过tf.train.Saver()恢复模型参数

运行

1
saver = tf.train.Saver()
通过saverrestore()方法可以从本地的模型文件中恢复模型参数。大致做法如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# your model's params
# you don't have to initialize them

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
W = tf.Variable(...)
b = tf.Variable(...)

y_ = tf.add(b, tf.matmul(x, w))

# create the saver

saver = tf.train.Saver()

# creat the session you used in the training processing
# launch the model

with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/your/path/model.ckpt")
print("Model restored.")
# Do some work with the model, such as do a prediction
pred = sess.run(y_, feed_dict={batch_x})
...
有关TensorFlow中变量的创建、存储及恢复操作,详细见API文档.

Step 2: 通过tf.trainable\_variables()得到训练参数

tf.trainable\_variables()方法将返回模型中所有可训练的参数,详细见API文档。类似于以下的变量参数不会被返回:

1
tf_var = tf.Variable(0, name="××××××", trainable=False)
还可以通过Variablename属性过滤出需要查看的参数,如下:
1
var = [v for v in t_vars if v.name == "W"]

本文来自:TensorFlow中遇到的问题及解决方法