用pytorch的nn.Module构造简单全链接层实例

作者:AItitanic 时间:2022-01-04 00:00:17 

python版本3.7,用的是虚拟环境安装的pytorch,这样随便折腾,不怕影响其他的python框架

1、先定义一个类Linear,继承nn.Module


import torch as t
from torch import nn
from torch.autograd import Variable as V

class Linear(nn.Module):

'''因为Variable自动求导,所以不需要实现backward()'''
 def __init__(self, in_features, out_features):
   super().__init__()
   self.w = nn.Parameter( t.randn( in_features, out_features ) ) #权重w 注意Parameter是一个特殊的Variable
   self.b = nn.Parameter( t.randn( out_features ) )   #偏值b

def forward( self, x ): #参数 x 是一个Variable对象
   x = x.mm( self.w )
   return x + self.b.expand_as( x ) #让b的形状符合 输出的x的形状

2、验证一下


layer = Linear( 4,3 )
input = V ( t.randn( 2 ,4 ) )#包装一个Variable作为输入
out = layer( input )
out

#成功运行,结果如下:

tensor([[-2.1934, 2.5590, 4.0233], [ 1.1098, -3.8182, 0.1848]], grad_fn=<AddBackward0>)

下面利用Linear构造一个多层网络


class Perceptron( nn.Module ):
 def __init__( self,in_features, hidden_features, out_features ):
   super().__init__()
   self.layer1 = Linear( in_features , hidden_features )
   self.layer2 = Linear( hidden_features, out_features )
 def forward ( self ,x ):
   x = self.layer1( x )
   x = t.sigmoid( x ) #用sigmoid()激活函数
   return self.layer2( x )

测试一下


perceptron = Perceptron ( 5,3 ,1 )

for name,param in perceptron.named_parameters():
 print( name, param.size() )

输出如预期:


layer1.w torch.Size([5, 3])
layer1.b torch.Size([3])
layer2.w torch.Size([3, 1])
layer2.b torch.Size([1])

来源:https://blog.csdn.net/AItitanic/article/details/97611356

标签:pytorch,nn.Module,全链接层
0
投稿

猜你喜欢

  • python GUI计算器的实现

    2021-11-30 01:00:53
  • Oracle 数据显示 横表转纵表

    2009-07-26 08:57:00
  • WEB前端开发高性能优化之JavaScript优化细节

    2009-06-10 14:38:00
  • Perl split字符串分割函数用法指南

    2023-08-13 01:28:36
  • 详解如何用SQLyog来分析MySQL数据库

    2008-10-13 12:35:00
  • Ubuntu下使用Python实现游戏制作中的切分图片功能

    2021-02-22 22:55:53
  • asp fso type属性取得文件类型代码

    2009-02-04 10:09:00
  • 转换字符串单词的第一个字母为大写

    2007-10-18 10:50:00
  • Access2K中的查询分析器

    2008-11-20 16:40:00
  • 关于多元线性回归分析——Python&SPSS

    2023-03-11 17:03:34
  • Python如何输出警告信息

    2022-01-25 23:34:44
  • 网马解密大讲堂——网马解密中级篇(Freshow工具使用方法)

    2009-09-16 15:09:00
  • 教你快速掌握怎样在Windows下升级MySQL

    2008-12-31 17:08:00
  • Python 使用多属性来进行排序

    2023-11-10 21:15:07
  • Go语言二进制文件的读写操作

    2023-06-23 09:40:08
  • 关于分页查询和性能问题

    2008-03-11 12:25:00
  • 一文详解go mod依赖管理详情

    2023-07-13 04:35:06
  • oracle日期分组查询的完整实例

    2023-06-26 10:14:13
  • 详细讲解Access数据库远程连接的实用方法

    2008-11-28 16:34:00
  • java正则表达式解析html示例分享

    2023-06-13 15:53:42
  • asp之家 网络编程 m.aspxhome.com