Pytorch实现简单自定义网络层的方法

作者:ting_qifengl 时间:2021-01-13 16:02:55 

前言

Pytorch、Tensoflow等许多深度学习框架集成了大量常见的网络层,为我们搭建神经网络提供了诸多便利。但在实际工作中,因为项目要求、研究需要或者 * 文需要等等,大家一般都会需要自己发明一个现在在深度学习框架中还不存在的层。 在这些情况下,就必须构建自定义层。

博主在学习了沐神的动手学深度学习这本书之后,学到了许多东西。这里记录一下书中基于Pytorch实现简单自定义网络层的方法,仅供参考。

一、不带参数的层

首先,我们构造一个没有任何参数的自定义层,要构建它,只需继承基础层类并实现前向传播功能。

import torch
import torch.nn.functional as F
from torch import nn
class CenteredLayer(nn.Module):
   def __init__(self):
       super().__init__()

def forward(self, X):
       return X - X.mean()

输入一些数据,验证一下网络是否能正常工作:

layer = CenteredLayer()
print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))

输出结果如下:

tensor([-2., -1.,  0.,  1.,  2.])

运行正常,表明网络没有问题。

现在将我们自建的网络层作为组件合并到更复杂的模型中,并输入数据进行验证:

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Y = net(torch.rand(4, 8))
print(Y.mean())  # 因为模型参数较多,输出也较多,所以这里输出Y的均值,验证模型可运行即可

结果如下:

tensor(-5.5879e-09, grad_fn=<MeanBackward0>)

二、带参数的层

这里使用内置函数来创建参数,这些函数可以提供一些基本的管理功能,使用更加方便。

这里实现了一个简单的自定义的全连接层,大家可根据需要自行修改即可。

class MyLinear(nn.Module):
   def __init__(self, in_units, units):
       super().__init__()
       self.weight = nn.Parameter(torch.randn(in_units, units))
       self.bias = nn.Parameter(torch.randn(units,))
   def forward(self, X):
       linear = torch.matmul(X, self.weight.data) + self.bias.data
       return F.relu(linear)

接下来实例化类并访问其模型参数:

linear = MyLinear(5, 3)
print(linear.weight)

结果如下:

Parameter containing:
tensor([[-0.3708,  1.2196,  1.3658],
        [ 0.4914, -0.2487, -0.9602],
        [ 1.8458,  0.3016, -0.3956],
        [ 0.0616, -0.3942,  1.6172],
        [ 0.7839,  0.6693, -0.8890]], requires_grad=True)

而后输入一些数据,查看模型输出结果:

print(linear(torch.rand(2, 5)))
# 结果如下
tensor([[1.2394, 0.0000, 0.0000],
       [1.3514, 0.0968, 0.6667]])

我们还可以使用自定义层构建模型,使用方法与使用内置的全连接层相同。

net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
print(net(torch.rand(2, 64)))
# 结果如下
tensor([[4.1416],
       [0.2567]])

三、总结

我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层,其行为与深度学习框架中的任何现有层不同。

在自定义层定义完成后,我们就可以在任意环境和网络架构中调用该自定义层。

层可以有局部参数,这些参数可以通过内置函数创建。

四、参考

《动手学深度学习》 &mdash; 动手学深度学习 2.0.0-beta0 documentation

https://zh-v2.d2l.ai/

附:pytorch获取网络的层数和每层的名字

#创建自己的网络
import models
model = models.__dict__["resnet50"](pretrained=True)

for index ,(name, param) in enumerate(model.named_parameters()):
? ? print( str(index) + " " +name)

结果如下:

0 conv1.weight
1 bn1.weight
2 bn1.bias
3 layer1.0.conv1.weight
4 layer1.0.bn1.weight
5 layer1.0.bn1.bias
6 layer1.0.conv2.weight
7 layer1.0.bn2.weight
8 layer1.0.bn2.bias
9 layer1.0.conv3.weight

来源:https://blog.csdn.net/ting_qifengl/article/details/124870577

标签:pytorch,自定义,网络层
0
投稿

猜你喜欢

  • Pandas Matplotlib保存图形时坐标轴标签太长导致显示不全问题的解决

    2023-07-22 20:03:09
  • python 爬取腾讯视频评论的实现步骤

    2021-06-19 03:57:58
  • Django如何防止定时任务并发浅析

    2021-10-23 01:09:54
  • Python使用Matplotlib实现Logos设计代码

    2021-02-04 19:18:34
  • PHP常用函数和常见疑难问题解答

    2023-11-08 19:28:17
  • Python使用reportlab将目录下所有的文本文件打印成pdf的方法

    2022-01-23 11:36:11
  • Python基于locals返回作用域字典

    2021-05-17 22:02:43
  • DRF跨域后端解决之django-cors-headers的使用

    2021-10-08 20:12:32
  • python实现AI聊天机器人详解流程

    2022-12-11 23:57:37
  • python算法学习之基数排序实例

    2023-01-07 05:24:52
  • 三分钟时间教你用Python绘制春联

    2023-11-06 00:26:08
  • Pytorch上下采样函数之F.interpolate数组采样操作详解

    2022-01-19 08:13:51
  • 自然语言处理NLP TextRNN实现情感分类

    2022-01-20 11:14:47
  • 使用virtualenv创建Python环境及PyQT5环境配置的方法

    2022-12-30 06:09:26
  • Django组件content-type使用方法详解

    2023-10-01 13:54:42
  • 对Server.UrlEncode进行字符反编译

    2009-06-22 12:54:00
  • css命名及书写规范大全

    2008-05-24 08:52:00
  • python批量翻译excel表格中的英文

    2022-11-16 08:09:07
  • Python编程求质数实例代码

    2021-12-03 23:17:18
  • 500行Python代码打造刷脸考勤系统

    2022-01-21 12:54:10
  • asp之家 网络编程 m.aspxhome.com