pytorch自定义不可导激活函数的操作

作者:Luna_Lovegood_001 时间:2022-07-05 10:09:13 

pytorch自定义不可导激活函数

今天自定义不可导函数的时候遇到了一个大坑。

首先我需要自定义一个函数:sign_f


import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
   @staticmethod
   def forward(ctx, inputs):
       output = inputs.new(inputs.size())
       output[inputs >= 0.] = 1
       output[inputs < 0.] = -1
       ctx.save_for_backward(inputs)
       return output

@staticmethod
   def backward(ctx, grad_output):
       input_, = ctx.saved_tensors
       grad_output[input_>1.] = 0
       grad_output[input_<-1.] = 0
       return grad_output

然后我需要把它封装为一个module 类型,就像 nn.Conv2d 模块 封装 f.conv2d 一样,于是


import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
# 我需要的module
   def __init__(self, *kargs, **kwargs):
       super(sign_, self).__init__(*kargs, **kwargs)

def forward(self, inputs):
   # 使用自定义函数
       outs = sign_f(inputs)
       return outs

class sign_f(Function):
   @staticmethod
   def forward(ctx, inputs):
       output = inputs.new(inputs.size())
       output[inputs >= 0.] = 1
       output[inputs < 0.] = -1
       ctx.save_for_backward(inputs)
       return output

@staticmethod
   def backward(ctx, grad_output):
       input_, = ctx.saved_tensors
       grad_output[input_>1.] = 0
       grad_output[input_<-1.] = 0
       return grad_output

结果报错

TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'

我试了半天,发现自定义函数后面要加 apply ,详细见下面


import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):

def __init__(self, *kargs, **kwargs):
       super(sign_, self).__init__(*kargs, **kwargs)
       self.r = sign_f.apply ### <-----注意此处

def forward(self, inputs):
       outs = self.r(inputs)
       return outs

class sign_f(Function):
   @staticmethod
   def forward(ctx, inputs):
       output = inputs.new(inputs.size())
       output[inputs >= 0.] = 1
       output[inputs < 0.] = -1
       ctx.save_for_backward(inputs)
       return output

@staticmethod
   def backward(ctx, grad_output):
       input_, = ctx.saved_tensors
       grad_output[input_>1.] = 0
       grad_output[input_<-1.] = 0
       return grad_output

问题解决了!

PyTorch自定义带学习参数的激活函数(如sigmoid)

有的时候我们需要给损失函数设一个超参数但是又不想设固定阈值想和网络一起自动学习,例如给Sigmoid一个参数alpha进行调节

pytorch自定义不可导激活函数的操作

pytorch自定义不可导激活函数的操作

函数如下:


import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
   def __init__(self, ):
       super(LearnableSigmoid, self).__init__()
       self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

self.reset_parameters()
   def reset_parameters(self):
       self.weight.data.fill_(1.0)

def forward(self, input):
       return 1/(1 +  torch.exp(-self.weight*input))

验证和Sigmoid的一致性


class LearnableSigmoid(nn.Module):
   def __init__(self, ):
       super(LearnableSigmoid, self).__init__()
       self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

self.reset_parameters()
   def reset_parameters(self):
       self.weight.data.fill_(1.0)

def forward(self, input):
       return 1/(1 +  torch.exp(-self.weight*input))

Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
input = torch.tensor([[0.5289, 0.1338, 0.3513],
       [0.4379, 0.1828, 0.4629],
       [0.4302, 0.1358, 0.4180]])

print(Sigmoid(input))
print(LearnSigmoid(input))

输出结果

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=<MulBackward0>)

验证权重是不是会更新


import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
   def __init__(self, ):
       super(LearnableSigmoid, self).__init__()
       self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

self.reset_parameters()

def reset_parameters(self):
       self.weight.data.fill_(1.0)

def forward(self, input):
       return 1/(1 +  torch.exp(-self.weight*input))

class Net(nn.Module):
   def __init__(self):
       super(Net, self).__init__()      
       self.LSigmoid = LearnableSigmoid()
   def forward(self, x):                
       x = self.LSigmoid(x)
       return x

net = Net()  
print(list(net.parameters()))
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)

for i in range(2):
   optimizer.zero_grad()    
   output = net(input_data)  
   loss = criterion(output, target)
   loss.backward()            
   optimizer.step()          
   print(list(net.parameters()))

输出结果

tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]

会更新~

来源:https://blog.csdn.net/qq_43110298/article/details/115032262

标签:pytorch,激活,函数
0
投稿

猜你喜欢

  • Python函数中定义参数的四种方式

    2021-10-11 03:10:20
  • Mysqlslap MySQL压力测试工具 简单教程

    2024-01-15 20:10:10
  • 利用Python实现读取Word文档里的Excel附件

    2022-01-21 11:28:18
  • php实现的支持断点续传的文件下载类

    2023-11-23 16:52:19
  • Python中字典及遍历常用函数的使用详解

    2021-06-25 13:06:03
  • Python 统计字数的思路详解

    2023-01-29 00:17:44
  • Vim中查找替换及正则表达式的使用详解

    2023-11-06 11:42:44
  • 关于go-zero服务自动收集问题分析

    2024-04-26 17:29:51
  • TensorFlow和keras中GPU使用的设置操作

    2023-08-07 20:32:53
  • canvas实现图片根据滑块放大缩小效果

    2024-04-16 09:52:15
  • 深入解析Python中的list列表及其切片和迭代操作

    2023-03-24 04:20:40
  • 如何将HTML字符转换为DOM节点并动态添加到文档中详解

    2023-08-23 12:26:39
  • MySQL的一些常用的SQL语句整理

    2024-01-19 06:38:40
  • Linux CentOS Python开发环境搭建教程

    2021-05-17 22:57:18
  • PHP header()函数使用详细(301、404等错误设置)

    2023-11-02 17:28:23
  • *.HTC 文件的简单介绍

    2008-11-24 17:36:00
  • asp是什么格式 asp文件用什么打开

    2020-06-30 16:04:48
  • 我们需要什么样的压力测试工具?

    2009-09-09 14:18:00
  • Python 字符串去除空格的五种方法

    2023-01-15 08:23:56
  • 按钮在 IE 中两边被拉伸的 BUG

    2008-11-17 20:37:00
  • asp之家 网络编程 m.aspxhome.com