Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]

Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]在使用tensorflow搭建模型时,需要定义许多变量,例如一个映射层就需要权重与偏置。当网络结果越来越复杂,变量越来越多的时候,就需要一个查看管理变量的函数,在tensorflow中,tf.trainable_variables(),tf.all_variables(),和tf.global_variables()可以来满足查看变量的要求,来简单说一下他们的不同。_tf.global_variables()

tf.trainable_variables(), tf.all_variables(), tf.global_variables()查看变量

在使用tensorflow搭建模型时,需要定义许多变量,例如一个映射层就需要权重与偏置。当网络结果越来越复杂,变量越来越多的时候,就需要一个查看管理变量的函数,在tensorflow中,tf.trainable_variables(), tf.all_variables(),和tf.global_variables()可以来满足查看变量的要求,来简单说一下他们的不同。

tf.trainable_variables()

顾名思义,这个函数可以也仅可以查看可训练的变量,在我们生成变量时,无论是使用tf.Variable()还是tf.get_variable()生成变量,都会涉及一个参数trainable,其默认为True。以tf.Variable()为例:

__init__(
    initial_value=None,
    trainable=True,
    collections=None,
    validate_shape=True,
   ...
)

对于一些我们不需要训练的变量,比较典型的例如学习率或者计步器这些变量,我们都需要将trainable设置为False,这时tf.trainable_variables() 就不会打印这些变量。举个简单的例子,在下图中共定义了4个变量,分别是一个权重矩阵,一个偏置向量,一个学习率和计步器,其中前两项是需要训练的而后两项则不需要。


Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]
这个时候tf.trainable_variables()效果如下:

Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]
另一个问题就是,如果变量定义在scope域中,是否会有不同。实际上,tf.trainable_variables()是可以通过参数选定域名的,如下图所示:

Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]
我们重新声明了两个新变量,其中w2是在‘var’中的,如果我们直接使用tf.trainable_variables(),结果如下:

Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]
但如果我们只希望查看‘var’域中的变量,我们可以通过加入scope参数的方式实现:

Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]
可以看到,只有w2被打印出来。

tf.global_variables()

回到第一个例子,如果我希望查看全部变量,包括我的学习率等信息,可以通过tf.global_variables()来实现。效果如下:


Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]
可以看到,这时候打印出来了4个变量,其中后两个即为trainable=False的学习率和计步器。与tf.trainable_variables()一样,tf.global_variables()也可以通过scope的参数来选定域中的变量。

tf.all_variables()

与tf.global_variables()作用拥有相似的功能,只是版本问题,可以看到:


Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]
运行时会有warning的提示。还有一点需要注意的是,tf.all_variables()似乎是没有scope输入参数的,这点作用性不如前两个那么强。

应用中

在实际代码中,我们可以在定义model的时候,定义一个内部函数用来查看模型中的变量,在训练过程中,可以在开始的时候调用一次,来看一下变量名称及其阶数,对模型控制性更强,了解更加明确。


Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]
今天的文章Tensorflow小技巧整理:tf.trainable_variables(), tf.all_variables(), tf.global_variables()的使用[通俗易懂]分享到此就结束了,感谢您的阅读,如果确实帮到您,您可以动动手指转发给其他人。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/57802.html

(0)
编程小号编程小号

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注