Pytorch 实现自定义参数层的例子
作者:青盏 时间:2023-01-27 22:00:06
注意,一般官方接口都带有可导功能,如果你实现的层不具有可导功能,就需要自己实现梯度的反向传递。
官方Linear层:
class Linear(Module):
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, input):
return F.linear(input, self.weight, self.bias)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
实现view层
class Reshape(nn.Module):
def __init__(self, *args):
super(Reshape, self).__init__()
self.shape = args
def forward(self, x):
return x.view((x.size(0),)+self.shape)
实现LinearWise层
class LinearWise(nn.Module):
def __init__(self, in_features, bias=True):
super(LinearWise, self).__init__()
self.in_features = in_features
self.weight = nn.Parameter(torch.Tensor(self.in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(self.in_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(0))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, input):
x = input * self.weight
if self.bias is not None:
x = x + self.bias
return x
来源:https://blog.csdn.net/qq_16234613/article/details/81604081
标签:Pytorch,自定义,参数层
0
投稿
猜你喜欢
关于Python正则表达式 findall函数问题详解
2022-10-24 22:18:43
网络编程之get与post的区别与联系
2023-01-01 09:40:37
举例讲解如何在Python编程中进行迭代和遍历
2023-07-12 04:42:30
pycharm 实现光标快速移动到括号外或行尾的操作
2023-07-17 19:52:31
Python Numpy学习之索引及切片的使用方法
2021-09-04 02:59:01
10个ASP网页制作技巧
2007-09-24 13:12:00
python opencv实现信用卡的数字识别
2023-07-05 02:20:23
php测试kafka项目示例
2023-11-19 20:40:04
教你为SQL Server数据库构造安全门
2009-01-20 11:34:00
详解Python如何实现尾递归优化
2023-11-13 04:20:06
Golang数据类型比较详解
2023-07-17 10:11:21
Node.js的非阻塞I/O、异步与事件驱动介绍
2024-05-13 09:35:02
asp利用XmlHttp和Adodb.Stream采集图片
2007-12-06 18:42:00
Pytorch之parameters的使用
2022-05-22 21:06:01
jquery和css3中的选择器nth-child使用方法和用途示例
2024-04-25 13:11:35
JavaScript的replace方法与正则表达式结合应用讲解
2008-03-06 21:37:00
脚本安全的本质_PHP+MYSQL第1/3页
2023-11-23 23:54:45
如何使用Python在2秒内评估国际象棋位置详解
2023-08-10 14:26:46
php中支持多种编码的中文字符串截取函数!
2023-09-27 02:08:15
python绘制横向水平柱状条形图
2022-01-10 01:01:49