pytorch交叉熵损失函数的weight参数的使用

作者:Nick Blog 时间:2021-02-27 15:52:31 

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。


class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码


import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)

tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?


import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)

tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:


import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)

tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

来源:https://niecongchong.blog.csdn.net/article/details/86594621

标签:pytorch,交叉熵损失,weight
0
投稿

猜你喜欢

  • 可以让程序告诉我详细的页面错误和数据库连接错误吗?

    2009-11-01 18:01:00
  • Python浅析迭代器Iterator的使用

    2023-11-07 12:04:25
  • Python3中FuzzyWuzzy库实例用法

    2022-01-30 18:49:49
  • asp开发中textarea常见问题

    2008-04-13 06:34:00
  • thinkphp5加layui实现图片上传功能(带图片预览)

    2023-06-13 01:09:45
  • Python开发毕设案例之桌面学生信息管理程序

    2021-03-02 14:56:08
  • Python使用grequests并发发送请求的示例

    2022-11-08 15:38:01
  • 详解一种用django_cache实现分布式锁的方式

    2023-11-08 03:50:45
  • Python实现KNN(K-近邻)算法的示例代码

    2023-09-25 15:56:18
  • Golang正整数指定规则排序算法问题分析

    2023-07-12 09:12:03
  • php中ob_flush函数和flush函数用法分析

    2023-11-15 06:12:59
  • JavaScript图片放大镜效果

    2009-10-19 22:15:00
  • 深度学习入门之Pytorch 数据增强的实现

    2021-04-05 22:26:07
  • Python多线程编程(二):启动线程的两种方法

    2023-11-27 16:15:48
  • pytest解读一次请求多个fixtures及多次请求

    2023-07-20 01:13:43
  • Python还能这么玩之用Python修改了班花的开机密码

    2023-11-23 17:38:40
  • php基础教程 php内置函数实例教程

    2023-11-14 18:28:45
  • jQuery 1.3的VS智能提示下载

    2009-01-18 12:54:00
  • 如何使用FSO修改文件夹的名称

    2008-10-11 14:24:00
  • Python的字符串操作简单实例

    2021-03-13 07:16:00
  • asp之家 网络编程 m.aspxhome.com