Python实现softmax反向传播的示例代码

作者:SugerOO 时间:2021-02-24 10:54:02 

概念

softmax函数是常用的输出层函数,常用来解决互斥标签的多分类问题。当然由于他是非线性函数,也可以作为隐藏层函数使用

反向传播求导

可以看到,softmax 计算了多个神经元的输入,在反向传播求导时,需要考虑对不同神经元的参数求导。

分两种情况考虑:

  • 当求导的参数位于分子时

  • 当求导的参数位于分母时

Python实现softmax反向传播的示例代码

当求导的参数位于分子时:

Python实现softmax反向传播的示例代码

当求导的参数位于分母时(ez2 or ez3这两个是对称的,求导结果是一样的):

Python实现softmax反向传播的示例代码

Python实现softmax反向传播的示例代码

代码

import torch
import math

def my_softmax(features):
   _sum = 0
   for i in features:
       _sum += math.e ** i
   return torch.Tensor([ math.e ** i / _sum for i in features ])

def my_softmax_grad(outputs):    
   n = len(outputs)
   grad = []
   for i in range(n):
       temp = []
       for j in range(n):
           if i == j:
               temp.append(outputs[i] * (1- outputs[i]))
           else:
               temp.append(-outputs[j] * outputs[i])
       grad.append(torch.Tensor(temp))
   return grad

if __name__ == '__main__':

features = torch.randn(10)
   features.requires_grad_()

torch_softmax = torch.nn.functional.softmax
   p1 = torch_softmax(features,dim=0)
   p2 = my_softmax(features)
   print(torch.allclose(p1,p2))

n = len(p1)
   p2_grad = my_softmax_grad(p2)
   for i in range(n):
       p1_grad = torch.autograd.grad(p1[i],features, retain_graph=True)
       print(torch.allclose(p1_grad[0], p2_grad[i]))

来源:https://blog.csdn.net/SugerOO/article/details/130032515

标签:Python,softmax,反向传播
0
投稿

猜你喜欢

  • PHP未登录自动跳转到登录页面

    2023-11-15 07:39:11
  • Python调用飞书发送消息的示例

    2022-10-20 14:21:23
  • Oracle数据库SQL语句性能调整的基本原则

    2009-03-25 16:55:00
  • Python获取航线信息并且制作成图的讲解

    2023-08-28 18:18:56
  • 使用ASP遍历并列表显示目录文件

    2009-11-08 18:32:00
  • javascript trim、left、right等函数,兼容IE,FireFox

    2009-09-18 14:55:00
  • 三条asp语句搞定路径

    2007-10-22 13:30:00
  • 用javascript做拖动层布局的思路

    2008-05-30 13:38:00
  • Python快速从注释生成文档的方法

    2022-07-11 04:55:37
  • 如何使用数组来显示下拉菜单?

    2010-05-16 15:19:00
  • Python序列的推导式实现代码

    2022-04-24 05:53:46
  • Pytorch mask-rcnn 实现细节分享

    2021-10-20 01:31:38
  • IE7下 filter:Alpha(opacity=xx) 的小问题

    2008-12-02 16:24:00
  • 什么是SVG(可升级矢量图形)

    2008-05-06 12:37:00
  • php浅析反序列化结构

    2023-11-17 17:34:37
  • python利用有道翻译实现"语言翻译器"的功能实例

    2021-08-21 02:47:38
  • my sql存储过程学习总结

    2011-07-12 19:12:35
  • Go语言包和包管理详解

    2023-07-21 15:51:03
  • 详细讲解SQL Server数据库的文件恢复技术

    2009-01-15 12:54:00
  • 恢复master..xp_logattach(log explorer)

    2010-07-01 19:19:00
  • asp之家 网络编程 m.aspxhome.com