基于MSELoss()与CrossEntropyLoss()的区别详解
作者:Foneone 发布时间:2022-05-17 19:18:27
标签:MSELoss,CrossEntropyLoss
基于pytorch来讲
MSELoss()多用于回归问题,也可以用于one_hotted编码形式,
CrossEntropyLoss()名字为交叉熵损失函数,不用于one_hotted编码形式
MSELoss()要求batch_x与batch_y的tensor都是FloatTensor类型
CrossEntropyLoss()要求batch_x为Float,batch_y为LongTensor类型
(1)CrossEntropyLoss() 举例说明:
比如二分类问题,最后一层输出的为2个值,比如下面的代码:
class CNN (nn.Module ) :
def __init__ ( self , hidden_size1 , output_size , dropout_p) :
super ( CNN , self ).__init__ ( )
self.hidden_size1 = hidden_size1
self.output_size = output_size
self.dropout_p = dropout_p
self.conv1 = nn.Conv1d ( 1,8,3,padding =1)
self.fc1 = nn.Linear (8*500, self.hidden_size1 )
self.out = nn.Linear (self.hidden_size1,self.output_size )
def forward ( self , encoder_outputs ) :
cnn_out = F.max_pool1d ( F.relu (self.conv1(encoder_outputs)),2)
cnn_out = F.dropout ( cnn_out ,self.dropout_p) #加一个dropout
cnn_out = cnn_out.view (-1,8*500)
output_1 = torch.tanh ( self.fc1 ( cnn_out ) )
output = self.out ( ouput_1)
return output
最后的输出结果为:
上面一个tensor为output结果,下面为target,没有使用one_hotted编码。
训练过程如下:
cnn_optimizer = torch.optim.SGD(cnn.parameters(),learning_rate,momentum=0.9,\
weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
def train ( input_variable , target_variable , cnn , cnn_optimizer , criterion ) :
cnn_output = cnn( input_variable )
print(cnn_output)
print(target_variable)
loss = criterion ( cnn_output , target_variable)
cnn_optimizer.zero_grad ()
loss.backward( )
cnn_optimizer.step( )
#print('loss: ',loss.item())
return loss.item() #返回损失
说明CrossEntropyLoss()是output两位为one_hotted编码形式,但target不是one_hotted编码形式。
(2)MSELoss() 举例说明:
网络结构不变,但是标签是one_hotted编码形式。下面的图仅做说明,网络结构不太对,出来的预测也不太对。
如果target不是one_hotted编码形式会报错,报的错误如下。
目前自己理解的两者的区别,就是这样的,至于多分类问题是不是也是样的有待考察。
来源:https://blog.csdn.net/foneone/article/details/90127707
0
投稿
猜你喜欢
- 很多人认为python中的字典是无序的,因为它是按照hash来存储的,但是python中有个模块collections(英文,收集、集合),
- 1、判断多个条件的语句,if为真则执行if后面的语句。2、如果elif是真的,则执行elif,后面的代码块不执行。3、如果if和elif不满
- 函数的返回值一个函数执行后可以返回多个返回值def measure(): print('测量开始。。。。&
- 本文实例为大家分享了基于神经卷积网络的人脸识别,供大家参考,具体内容如下1.人脸识别整体设计方案客_服交互流程图:2.服务端代码展示sk =
- 目的工作中遇到一个需求,通过需要通过网站查询船舶名称得到MMSI码,网站来自船讯网。分析请求根据以往爬虫的经验,打开F12,通过输入船舶名称
- 在python3.6版本中去掉了os.path.walk()函数os.walk()函数声明:walk(top,topdown=True,on
- python中安装包的方式有很多种:源码包:python setup.py install在线安装:pip install 包名(linux
- 概述在实践中,我们发现上述的代码重复率非常高,新增和修改都费力,并且是没技术含量的体力活。 但又必须要这样做,不适合以公共函数的形式重用,为
- 环境 python3.8pycharm2021.2知识点requests >>> pip install req
- 1,不用第三方库# coding: utf-8import loggingBLACK, RED, GREEN, YELLOW, BLUE,
- 淘宝的 NPM 镜像是一个完整的npmjs.org镜像。你可以用此代替官方版本(只读),同步频率目前为 15分钟 一次以保证尽量与官方服务同
- 本文实例总结了Python操作redis方法。分享给大家供大家参考,具体如下:python连接方式可参考:https://www.jb51.
- 先看一下效果图: index.wxml <view class='{{tabIsTop ? "fixedT
- 前言上一篇介绍了客户端流式RPC,客户端不断的向服务端发送数据流,在发送结束或流关闭后,由服务端返回一个响应。本篇将介绍双向流式RPC。双向
- 如下所示:import matplotlib.pyplot as pltimport numpy as npdef readfile(fil
- 实际操作中我们经常需要寻找数据的某行或者某列,这里介绍我在使用Pandas时用到的两种方法:iloc和loc。loc:通过行、列的名称或标签
- 1:创建用户 create temporary tablespace user_temp tempfile 'D:\app\topw
- 比如有下面一段代码: for i in range(10): print ("%s" % (f_list[i].name
- 在list中嵌套元组,在进行sort排序的时候,产生的是原数组的副本,排序过程中,先根据第一个字段进行从小到大排序,如果第一个字段相同的话,
- 这一次将使用pymysql来进行一次对MySQL的增删改查的全部操作,相当于对前五次的总结:先查阅数据库:现在编写源码进行增删改查操作,源码