pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
作者:aift 时间:2021-06-03 09:28:09
公式
首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:
其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。
测试代码(一维)
import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)
print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)
first = 0
for i in range(1):
first = -output[i][label[i]]
second = 0
for i in range(1):
for j in range(5):
second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)
测试代码(多维)
import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)
print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)
first = [0, 0, 0]
for i in range(3):
first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
for j in range(5):
second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)
nn.CrossEntropyLoss()中的计算方法
注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。
在pytorch中,CrossEntropyLoss计算公式为:
CrossEntropyLoss带权重的计算公式为(默认weight=None):
来源:https://blog.csdn.net/ft_sunshine/article/details/92074842
标签:pytorch,交叉熵损失,nn.CrossEntropyLoss
0
投稿
猜你喜欢
Django 添加静态文件的两种实现方法(必看篇)
2021-09-03 23:53:58
Java中使用正则表达式的一个简单例子及常用正则分享
2023-05-06 09:03:16
goland -sync/atomic原子操作小结
2024-04-26 17:20:08
Python进度条的制作代码实例
2022-01-01 23:17:34
Python安装图文教程 Pycharm安装教程
2022-06-19 20:03:05
如何用C代码给Python写扩展库(Cython)
2023-06-08 17:06:32
C#基于数据库存储过程的AJAX分页实例
2024-01-26 20:43:23
Python中的字符串切片(截取字符串)的详解
2023-07-23 20:37:59
Win10系统提示“Python 0x80070643安装时发生严重错误”怎么办?
2023-06-13 06:50:25
vue项目使用node连接数据库的方法(前后端分离)
2024-01-27 14:08:54
python实现神经网络感知器算法
2021-03-06 11:23:39
Python实现判断并移除列表指定位置元素的方法
2023-03-21 03:06:19
如何理解PHP程序执行的过程原理
2023-10-08 14:45:10
Node.js readline模块与util模块的使用
2024-05-11 10:13:00
PyTorch中常用的激活函数的方法示例
2022-11-02 01:17:44
Python可视化最频繁使用的10大工具总结
2022-07-01 04:00:17
SQL服务器无法启动的解决方法
2024-01-16 04:47:20
Python中pyecharts安装及安装失败的解决方法
2021-01-13 06:00:52
利用PyInstaller将python程序.py转为.exe的方法详解
2021-07-09 16:41:51
python中操作文件的模块的方法总结
2022-02-01 04:56:14