pytorch 如何自定义卷积核权值参数

作者:Mr.Jcak 时间:2021-10-30 19:10:22 

pytorch中构建卷积层一般使用nn.Conv2d方法,有些情况下我们需要自定义卷积核的权值weight,而nn.Conv2d中的卷积参数是不允许自定义的,此时可以使用torch.nn.functional.conv2d简称F.conv2d


torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

F.conv2d可以自己输入且也必须要求自己输入卷积权值weight和偏置bias。因此,构建自己想要的卷积核参数,再输入F.conv2d即可。

下面是一个用F.conv2d构建卷积层的例子

这里为了网络模型需要写成了一个类:


class CNN(nn.Module):
   def __init__(self):
       super(CNN, self).__init__()
       self.weight = nn.Parameter(torch.randn(16, 1, 5, 5))  # 自定义的权值
       self.bias = nn.Parameter(torch.randn(16))    # 自定义的偏置

def forward(self, x):
       x = x.view(x.size(0), -1)
       out = F.conv2d(x, self.weight, self.bias, stride=1, padding=0)
       return out

值得注意的是,pytorch中各层需要训练的权重的数据类型设为nn.Parameter,而不是Tensor或者Variable。parameter的require_grad默认设置为true,而Varaible默认设置为False。

补充:pytorch中卷积参数的理解

pytorch 如何自定义卷积核权值参数

pytorch 如何自定义卷积核权值参数

pytorch 如何自定义卷积核权值参数

kernel_size代表着卷积核,例如kernel_size=3或kernel_size=(3,7);

stride:表明卷积核在像素级图像上行走的步长,如图2,步长为1;

padding:为上下左右填充的大小,例如padding=0/1/(1,1)/(1,3),

padding=0 不填充;

padding=1/(1,1) 上下左右分别填充1个格;

padding=(1,3) 高(上下)填充2个格,宽(左右)填充6个格;

卷积代码


torch.nn.Conv2d(512,512,kernel_size=(3,7),stride=2,padding=1)

指定输出形状的上采样


def upsample_add(self,x,y):
       _,_,H,W = y.size()
       return F.interpolate(x, size=(H,W), mode='bilinear', align_corners=False) + y

反卷积上采样


output_shape_w=kernel_size_w+(output_w-1)(kernel_size_w-1)+2padding

self.upscore2 = nn.ConvTranspose2d(
           512, 1, kernel_size=3, stride=2,padding=0, bias=False)

来源:https://blog.csdn.net/weixin_38314865/article/details/105941140

标签:pytorch,卷积核,权值,参数
0
投稿

猜你喜欢

  • MYsql库与表的管理及视图介绍

    2024-01-25 21:33:06
  • 详解如何用Python登录豆瓣并爬取影评

    2021-09-08 00:10:10
  • Pytorch 实现冻结指定卷积层的参数

    2023-05-22 07:27:21
  • python入门学习笔记分享

    2023-01-29 17:46:16
  • 十个Python练手的实战项目,学会这些Python就基本没问题了(推荐)

    2022-07-21 04:55:52
  • MySQL如何统计一个数据库所有表的数据量

    2024-01-23 20:07:14
  • python通过对字典的排序,对json字段进行排序的实例

    2023-06-15 02:20:40
  • java连接mysql数据库乱码的解决方法

    2024-01-21 06:26:15
  • vue3中关于路由hash与History的设置

    2024-05-13 09:14:24
  • 如何让Python在HTML中运行

    2023-06-13 08:21:28
  • Python绘制地理图表可视化神器pyecharts

    2021-01-22 18:08:58
  • Python实现Telnet自动连接检测密码的示例

    2021-10-05 11:08:37
  • python数字图像处理环境安装与配置过程示例

    2023-03-05 07:00:25
  • 浅析mysql 定时备份任务

    2024-01-17 07:21:01
  • Python简单实现Base64编码和解码的方法

    2022-05-14 12:22:35
  • 机器学习10大经典算法详解

    2021-02-21 01:39:57
  • 浅谈利用numpy对矩阵进行归一化处理的方法

    2021-10-12 01:22:41
  • php上传大文件设置方法

    2023-11-21 19:11:22
  • Python中使用logging模块打印log日志详解

    2021-10-01 02:32:17
  • python 集合常用操作汇总

    2023-11-15 08:41:22
  • asp之家 网络编程 m.aspxhome.com