Pytorch 实现自定义参数层的例子

作者:青盏 时间:2023-01-27 22:00:06 

注意,一般官方接口都带有可导功能,如果你实现的层不具有可导功能,就需要自己实现梯度的反向传递。

官方Linear层:


class Linear(Module):
 def __init__(self, in_features, out_features, bias=True):
   super(Linear, self).__init__()
   self.in_features = in_features
   self.out_features = out_features
   self.weight = Parameter(torch.Tensor(out_features, in_features))
   if bias:
     self.bias = Parameter(torch.Tensor(out_features))
   else:
     self.register_parameter('bias', None)
   self.reset_parameters()

def reset_parameters(self):
   stdv = 1. / math.sqrt(self.weight.size(1))
   self.weight.data.uniform_(-stdv, stdv)
   if self.bias is not None:
     self.bias.data.uniform_(-stdv, stdv)

def forward(self, input):
   return F.linear(input, self.weight, self.bias)

def extra_repr(self):
   return 'in_features={}, out_features={}, bias={}'.format(
     self.in_features, self.out_features, self.bias is not None
   )

实现view层


class Reshape(nn.Module):
 def __init__(self, *args):
   super(Reshape, self).__init__()
   self.shape = args

def forward(self, x):
   return x.view((x.size(0),)+self.shape)

实现LinearWise层


class LinearWise(nn.Module):
 def __init__(self, in_features, bias=True):
   super(LinearWise, self).__init__()
   self.in_features = in_features

self.weight = nn.Parameter(torch.Tensor(self.in_features))
   if bias:
     self.bias = nn.Parameter(torch.Tensor(self.in_features))
   else:
     self.register_parameter('bias', None)
   self.reset_parameters()

def reset_parameters(self):
   stdv = 1. / math.sqrt(self.weight.size(0))
   self.weight.data.uniform_(-stdv, stdv)
   if self.bias is not None:
     self.bias.data.uniform_(-stdv, stdv)

def forward(self, input):
   x = input * self.weight
   if self.bias is not None:
     x = x + self.bias
   return x

来源:https://blog.csdn.net/qq_16234613/article/details/81604081

标签:Pytorch,自定义,参数层
0
投稿

猜你喜欢

  • 关于Python正则表达式 findall函数问题详解

    2022-10-24 22:18:43
  • 网络编程之get与post的区别与联系

    2023-01-01 09:40:37
  • 举例讲解如何在Python编程中进行迭代和遍历

    2023-07-12 04:42:30
  • pycharm 实现光标快速移动到括号外或行尾的操作

    2023-07-17 19:52:31
  • Python Numpy学习之索引及切片的使用方法

    2021-09-04 02:59:01
  • 10个ASP网页制作技巧

    2007-09-24 13:12:00
  • python opencv实现信用卡的数字识别

    2023-07-05 02:20:23
  • php测试kafka项目示例

    2023-11-19 20:40:04
  • 教你为SQL Server数据库构造安全门

    2009-01-20 11:34:00
  • 详解Python如何实现尾递归优化

    2023-11-13 04:20:06
  • Golang数据类型比较详解

    2023-07-17 10:11:21
  • Node.js的非阻塞I/O、异步与事件驱动介绍

    2024-05-13 09:35:02
  • asp利用XmlHttp和Adodb.Stream采集图片

    2007-12-06 18:42:00
  • Pytorch之parameters的使用

    2022-05-22 21:06:01
  • jquery和css3中的选择器nth-child使用方法和用途示例

    2024-04-25 13:11:35
  • JavaScript的replace方法与正则表达式结合应用讲解

    2008-03-06 21:37:00
  • 脚本安全的本质_PHP+MYSQL第1/3页

    2023-11-23 23:54:45
  • 如何使用Python在2秒内评估国际象棋位置详解

    2023-08-10 14:26:46
  • php中支持多种编码的中文字符串截取函数!

    2023-09-27 02:08:15
  • python绘制横向水平柱状条形图

    2022-01-10 01:01:49
  • asp之家 网络编程 m.aspxhome.com