pytorch交叉熵损失函数的weight参数的使用

作者:Nick Blog 时间:2021-02-27 15:52:31 

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。


class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码


import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)

tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?


import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)

tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:


import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)

tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

来源:https://niecongchong.blog.csdn.net/article/details/86594621

标签:pytorch,交叉熵损失,weight
0
投稿

猜你喜欢

  • vue3+ts如何通过lodash实现防抖节流详解

    2024-05-02 16:32:13
  • 微信小程序与php 实现微信支付的简单实例

    2023-11-14 15:22:07
  • 使用sysbench来测试MySQL性能的详细教程

    2024-01-14 14:33:54
  • python训练数据时打乱训练数据与标签的两种方法小结

    2021-11-17 11:15:12
  • python导入pandas具体步骤方法

    2022-08-14 16:01:03
  • pygame游戏之旅 创建游戏窗口界面

    2022-05-19 18:42:00
  • 在python里协程使用同步锁Lock的实例

    2022-07-31 14:26:04
  • 说说CSS的优先权 考虑CSS的继承与层叠

    2008-12-11 13:33:00
  • Mac OS下PHP环境搭建及PHP操作MySQL常用方法小结

    2024-05-08 10:16:31
  • win7系统安装SQLServer2000的详细步骤(图文)

    2024-01-27 02:39:26
  • php注册和登录界面的实现案例(推荐)

    2024-04-30 08:48:47
  • Python实现学生管理系统并生成exe可执行文件详解流程

    2023-03-11 04:52:42
  • .Net Core服务治理Consul使用服务发现

    2023-06-25 07:49:19
  • 利用python爬取软考试题之ip自动代理

    2023-01-30 01:17:28
  • javascript实现右下角广告框效果

    2024-04-17 10:25:08
  • Python3列表List入门知识附实例

    2023-03-12 06:41:22
  • 深度学习TextRNN的tensorflow1.14实现示例

    2023-12-31 18:59:23
  • 数据库自动化技术弥补数据库DBA短缺难题

    2009-02-04 16:53:00
  • Python3基础之list列表实例解析

    2022-04-22 16:07:15
  • python遍历迭代器自动链式处理数据的实例代码

    2022-04-12 18:38:29
  • asp之家 网络编程 m.aspxhome.com