Understanding TensorFlow's softmax Function

Author: ASR_THU  Date: 2022-10-29 07:42:09

The implementation is shown below:


def softmax(logits, axis=None, name=None, dim=None):
  """Computes softmax activations.

  This function performs the equivalent of

      softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).
    dim: Deprecated alias for `axis`.

  Returns:
    A `Tensor`. Has the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
  if axis is None:
    axis = -1
  return _softmax(logits, gen_nn_ops.softmax, axis, name)

The result returned by softmax has the same shape as the input tensor. If the shape is unchanged, what exactly does softmax do to the tensor?

The answer is that softmax picks one axis and, for each position along the other axes, applies activation (exponentiation) and normalization to the values along that axis.

In general, this axis is the one that represents the classes (tf.nn.softmax defaults to axis=-1, i.e. the last dimension).
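The defining formula from the docstring above can be checked with a minimal NumPy sketch (NumPy stands in for TensorFlow here; the function name `softmax_np` is my own):

```python
import numpy as np

def softmax_np(x, axis=-1):
    """Softmax along `axis`, mirroring tf.exp(logits) / tf.reduce_sum(...)."""
    # subtract the per-slice max for numerical stability; this cancels out
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

x = np.array([[1.0, 2.0, 3.0],
              [1.0, 2.0, 3.0]])
p = softmax_np(x, axis=-1)
print(p.shape)         # same shape as the input: (2, 3)
print(p.sum(axis=-1))  # each row sums to 1
```

Note that the output shape equals the input shape; only the values along the chosen axis are rescaled into a distribution.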

Example:


import numpy as np


def softmax(X, theta=1.0, axis=None):
    """
    Compute the softmax of each element along an axis of X.

    Parameters
    ----------
    X: ND-Array. Probably should be floats.
    theta (optional): float parameter, used as a multiplier
        prior to exponentiation. Default = 1.0
    axis (optional): axis to compute values along. Default is the
        first non-singleton axis.

    Returns an array the same size as X. The result will sum to 1
    along the specified axis.
    """
    # make X at least 2d
    y = np.atleast_2d(X)

    # find axis
    if axis is None:
        axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1)

    # multiply y against the theta parameter
    y = y * float(theta)

    # subtract the max for numerical stability
    y = y - np.expand_dims(np.max(y, axis=axis), axis)

    # exponentiate y
    y = np.exp(y)

    # take the sum along the specified axis
    ax_sum = np.expand_dims(np.sum(y, axis=axis), axis)

    # finally: divide elementwise
    p = y / ax_sum

    # flatten if X was 1D
    if len(X.shape) == 1:
        p = p.flatten()

    return p


c = np.random.randn(2, 3)
print(c)
# suppose axis 0 is the class axis, with 2 classes
cc = softmax(c, axis=0)
# suppose the last axis is the class axis, with 3 classes
ccc = softmax(c, axis=-1)
print(cc)
print(ccc)

Result:


c:
[[-1.30022268 0.59127472 1.21384177]
[ 0.1981082 -0.83686108 -1.54785864]]
cc:
[[0.1826746 0.80661068 0.94057075]
[0.8173254 0.19338932 0.05942925]]
ccc:
[[0.0500392 0.33172426 0.61823654]
[0.65371718 0.23222472 0.1140581 ]]

As you can see, when softmax is taken along axis=0, the output sums to 1 along axis=0 (e.g. 0.1826746 + 0.8173254 = 1); likewise, taking it along axis=1 makes the output sum to 1 along axis=1 (e.g. 0.0500392 + 0.33172426 + 0.61823654 = 1).
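The sums quoted above can be checked directly from the printed values:

```python
# values copied from the printed cc and ccc above
cc_col = [0.1826746, 0.8173254]                 # cc[:, 0], softmax along axis=0
ccc_row = [0.0500392, 0.33172426, 0.61823654]   # ccc[0], softmax along axis=1

print(sum(cc_col))   # 1.0 (up to rounding)
print(sum(ccc_row))  # 1.0 (up to rounding)
```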

How are these values computed?

Taking cc as an example (softmax along axis=0), each entry is the exponential of the corresponding entry of c, divided by the sum of exponentials over its column:

cc[0,0] = exp(c[0,0]) / (exp(c[0,0]) + exp(c[1,0]))

Taking ccc as an example (softmax along axis=1), the sum runs over the row instead:

ccc[0,0] = exp(c[0,0]) / (exp(c[0,0]) + exp(c[0,1]) + exp(c[0,2]))
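These per-entry computations can be reproduced numerically from the printed values of c:

```python
import numpy as np

# c as printed above
c = np.array([[-1.30022268,  0.59127472,  1.21384177],
              [ 0.1981082,  -0.83686108, -1.54785864]])

# cc[0,0]: softmax along axis=0, normalizing over the first column
cc00 = np.exp(c[0, 0]) / (np.exp(c[0, 0]) + np.exp(c[1, 0]))
print(cc00)   # matches the printed cc[0,0] ≈ 0.1826746

# ccc[0,0]: softmax along axis=1, normalizing over the first row
ccc00 = np.exp(c[0, 0]) / np.exp(c[0]).sum()
print(ccc00)  # matches the printed ccc[0,0] ≈ 0.0500392
```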

Now that we know how the values are computed, let's discuss what they actually mean:

cc[0,0] represents the probability P( label = 0 | value = [-1.30022268 0.1981082] = c[:,0] ) = 0.1826746

cc[1,0] represents the probability P( label = 1 | value = [-1.30022268 0.1981082] = c[:,0] ) = 0.8173254

ccc[0,0] represents the probability P( label = 0 | value = [-1.30022268 0.59127472 1.21384177] = c[0] ) = 0.0500392

ccc[0,1] represents the probability P( label = 1 | value = [-1.30022268 0.59127472 1.21384177] = c[0] ) = 0.33172426

ccc[0,2] represents the probability P( label = 2 | value = [-1.30022268 0.59127472 1.21384177] = c[0] ) = 0.61823654

Extending this to higher dimensions: suppose c is a 3-D tensor of shape [batch_size, timesteps, categories] and

output = tf.nn.softmax(c, axis=-1)

Then output[1, 2, 3] represents P( label = 3 | value = c[1,2] ).
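A sketch of the three-dimensional case, again using NumPy in place of tf.nn.softmax (the shapes 4, 5, 4 are illustrative assumptions, not from the original article):

```python
import numpy as np

batch_size, timesteps, categories = 4, 5, 4
rng = np.random.default_rng(0)
c = rng.standard_normal((batch_size, timesteps, categories))

# softmax over the last axis (the categories axis)
e = np.exp(c - c.max(axis=-1, keepdims=True))
output = e / e.sum(axis=-1, keepdims=True)

# output[1, 2] is a probability distribution over the 4 categories,
# and output[1, 2, 3] plays the role of P(label = 3 | value = c[1, 2])
print(output.shape)        # (4, 5, 4), same as c
print(output[1, 2].sum())  # 1.0 (up to floating point)
print(output[1, 2, 3])
```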

Source: https://blog.csdn.net/zongza/article/details/88016668

Tags: tensorflow, softmax