pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

作者:aift 时间:2021-06-03 09:28:09 

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)


import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
 first = -output[i][label[i]]
second = 0
for i in range(1):
 for j in range(5):
   second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)


import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
 first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
 for j in range(5):
   second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
 res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

来源:https://blog.csdn.net/ft_sunshine/article/details/92074842

标签:pytorch,交叉熵损失,nn.CrossEntropyLoss
0
投稿

猜你喜欢

  • Django 添加静态文件的两种实现方法(必看篇)

    2021-09-03 23:53:58
  • Java中使用正则表达式的一个简单例子及常用正则分享

    2023-05-06 09:03:16
  • goland -sync/atomic原子操作小结

    2024-04-26 17:20:08
  • Python进度条的制作代码实例

    2022-01-01 23:17:34
  • Python安装图文教程 Pycharm安装教程

    2022-06-19 20:03:05
  • 如何用C代码给Python写扩展库(Cython)

    2023-06-08 17:06:32
  • C#基于数据库存储过程的AJAX分页实例

    2024-01-26 20:43:23
  • Python中的字符串切片(截取字符串)的详解

    2023-07-23 20:37:59
  • Win10系统提示“Python 0x80070643安装时发生严重错误”怎么办?

    2023-06-13 06:50:25
  • vue项目使用node连接数据库的方法(前后端分离)

    2024-01-27 14:08:54
  • python实现神经网络感知器算法

    2021-03-06 11:23:39
  • Python实现判断并移除列表指定位置元素的方法

    2023-03-21 03:06:19
  • 如何理解PHP程序执行的过程原理

    2023-10-08 14:45:10
  • Node.js readline模块与util模块的使用

    2024-05-11 10:13:00
  • PyTorch中常用的激活函数的方法示例

    2022-11-02 01:17:44
  • Python可视化最频繁使用的10大工具总结

    2022-07-01 04:00:17
  • SQL服务器无法启动的解决方法

    2024-01-16 04:47:20
  • Python中pyecharts安装及安装失败的解决方法

    2021-01-13 06:00:52
  • 利用PyInstaller将python程序.py转为.exe的方法详解

    2021-07-09 16:41:51
  • python中操作文件的模块的方法总结

    2022-02-01 04:56:14
  • asp之家 网络编程 m.aspxhome.com