pytorch 实现在预训练模型的 input上增减通道

作者:Hi_AI 时间:2023-12-02 00:49:33 

如何把imagenet预训练的模型,输入层的通道数随心所欲的修改,从而来适应自己的任务


#增加一个通道
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, w[:, :1, :, :]), dim=1))

#方式2
w = layers[0].weight
layers[0] = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(torch.cat((w, torch.zeros(64, 1, 7, 7)), dim=1))

#单通道输入
layers[0] = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layers[0].weight = torch.nn.Parameter(w[:, :1, :, :])

来源:https://blog.csdn.net/github_36923418/article/details/84567227

标签:pytorch,预训练模型,增减,通道
0
投稿

猜你喜欢

  • pytorch tensorboard可视化的使用详解

    2022-09-27 01:01:51
  • IE中radio 或checkbox的checked属性初始状态下不能选中显示问题

    2024-05-10 14:06:42
  • Python中三个不可思议的返回功能分享

    2021-11-21 07:32:41
  • 详解mysql数据库增删改操作

    2024-01-15 02:23:05
  • 网页设计配色基础:RGB与HSB

    2008-05-06 12:23:00
  • golang网络通信超时设置方式

    2024-05-09 09:39:27
  • python中xrange和range的区别

    2023-03-14 05:54:35
  • MySQL在线开启或禁用GTID模式

    2024-01-24 01:13:52
  • asp如何在ADO中使用存储查询?

    2010-06-17 12:52:00
  • MySQL按天分组统计一定时间内的数据实例(没有数据补0)

    2024-01-17 07:42:08
  • windows下python模拟鼠标点击和键盘输示例

    2021-11-12 21:06:32
  • pandas分批读取大数据集教程

    2023-01-13 16:45:32
  • django 将model转换为字典的方法示例

    2022-09-16 14:03:09
  • Python+Opencv身份证号码区域提取及识别实现

    2021-10-01 17:32:13
  • 对python文件读写的缓冲行为详解

    2022-11-09 09:59:08
  • 使用python PIL库实现简单验证码的去噪方法步骤

    2022-05-05 00:48:46
  • 最新LOGO设计流行趋势——叶子

    2007-10-02 18:26:00
  • 解决node.js安装包失败的几种方法

    2024-05-08 09:36:34
  • Python PaddleNLP开源实现快递单信息抽取

    2023-01-21 04:35:11
  • Python正则表达式中的'r'用法总结

    2021-08-22 23:16:34
  • asp之家 网络编程 m.aspxhome.com