Tensorflow中的dropout的使用方法

作者:AGUILLER 时间:2021-04-22 01:21:58 

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout


def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout


def dropout(inputs,
  rate=0.5,
  noise_shape=None,
  seed=None,
  training=False,
  name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。


def sparse_dropout(x, keep_prob, noise_shape):
keep_tensor = keep_prob + tf.random_uniform(noise_shape)
drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
out = tf.sparse_retain(x, drop_mask)
return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例


def nn_dropout(x, keep_prob, noise_shape):
out = tf.nn.dropout(x, keep_prob, noise_shape)
return out

def layers_dropout(x, keep_prob, noise_shape, training=False):
out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
return out

def sparse_dropout(x, keep_prob, noise_shape):
keep_tensor = keep_prob + tf.random_uniform(noise_shape)
drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
out = tf.sparse_retain(x, drop_mask)
return out * (1.0/keep_prob)

if __name__ == '__main__':
inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
inputs2 = tf.sparse_tensor_to_dense(inputs1)
nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 (in1, in2) = sess.run([inputs1, inputs2])
 print(in1)
 print(in2)
 (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
 print(out1)
 print(out2)
 print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,


SparseTensorValue(indices=array([[0, 0],
 [0, 2],
 [1, 1],
 [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
[ 0. 3. 4.]]

[[ 2. 0. 0.]
[ 0. 0. 0.]]
[[ 0. 0. 4.]
[ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

来源:https://segmentfault.com/a/1190000021997935

标签:Tensorflow,dropout
0
投稿

猜你喜欢

  • python使用pip安装SciPy、SymPy、matplotlib教程

    2022-03-05 01:46:12
  • 使用Python横向合并excel文件的实例

    2023-09-19 21:20:18
  • MSSQL段落还原脚本,SQLSERVER段落脚本

    2024-01-22 14:48:15
  • perl获取日期与时间的实例代码

    2023-03-30 23:57:01
  • MySQL数据库常用命令小结

    2024-01-15 22:16:31
  • Python selenium的基本使用方法分析

    2021-11-04 10:01:30
  • python通过colorama模块在控制台输出彩色文字的方法

    2023-07-23 00:35:51
  • python -v 报错问题的解决方法

    2022-04-03 03:07:29
  • 八个有用的WordPress的SQL语句

    2009-01-12 18:54:00
  • vue 集成jTopo 处理方法

    2024-05-09 15:17:42
  • Django中的ajax请求

    2022-10-19 10:28:14
  • python文件及目录操作代码汇总

    2022-08-19 14:07:27
  • mysql binlog二进制日志详解

    2024-01-19 09:09:56
  • python环境路径配置以及命令行运行脚本

    2023-09-19 21:19:05
  • Frontpage轻松下载网页或站点

    2007-10-22 13:14:00
  • CentOS 7下安装Python 3.5并与Python2.7兼容并存详解

    2021-09-18 03:03:32
  • 配置 Pycharm 默认 Test runner 的图文教程

    2023-12-06 09:03:32
  • python将秒数转化为时间格式的实例

    2023-09-24 12:10:22
  • Python栈算法的实现与简单应用示例

    2023-11-16 23:18:30
  • 利用Python实现颜色色值转换的小工具

    2021-09-12 07:44:15
  • asp之家 网络编程 m.aspxhome.com