用pytorch的nn.Module构造简单全链接层实例
作者:AItitanic 时间:2022-01-04 00:00:17
python版本3.7,用的是虚拟环境安装的pytorch,这样随便折腾,不怕影响其他的python框架
1、先定义一个类Linear,继承nn.Module
import torch as t
from torch import nn
from torch.autograd import Variable as V
class Linear(nn.Module):
'''因为Variable自动求导,所以不需要实现backward()'''
def __init__(self, in_features, out_features):
super().__init__()
self.w = nn.Parameter( t.randn( in_features, out_features ) ) #权重w 注意Parameter是一个特殊的Variable
self.b = nn.Parameter( t.randn( out_features ) ) #偏值b
def forward( self, x ): #参数 x 是一个Variable对象
x = x.mm( self.w )
return x + self.b.expand_as( x ) #让b的形状符合 输出的x的形状
2、验证一下
layer = Linear( 4,3 )
input = V ( t.randn( 2 ,4 ) )#包装一个Variable作为输入
out = layer( input )
out
#成功运行,结果如下:
tensor([[-2.1934, 2.5590, 4.0233], [ 1.1098, -3.8182, 0.1848]], grad_fn=<AddBackward0>)
下面利用Linear构造一个多层网络
class Perceptron( nn.Module ):
def __init__( self,in_features, hidden_features, out_features ):
super().__init__()
self.layer1 = Linear( in_features , hidden_features )
self.layer2 = Linear( hidden_features, out_features )
def forward ( self ,x ):
x = self.layer1( x )
x = t.sigmoid( x ) #用sigmoid()激活函数
return self.layer2( x )
测试一下
perceptron = Perceptron ( 5,3 ,1 )
for name,param in perceptron.named_parameters():
print( name, param.size() )
输出如预期:
layer1.w torch.Size([5, 3])
layer1.b torch.Size([3])
layer2.w torch.Size([3, 1])
layer2.b torch.Size([1])
来源:https://blog.csdn.net/AItitanic/article/details/97611356
标签:pytorch,nn.Module,全链接层
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
python GUI计算器的实现
2021-11-30 01:00:53
![](https://img.aspxhome.com/file/2023/8/87948_0s.png)
Oracle 数据显示 横表转纵表
2009-07-26 08:57:00
WEB前端开发高性能优化之JavaScript优化细节
2009-06-10 14:38:00
Perl split字符串分割函数用法指南
2023-08-13 01:28:36
详解如何用SQLyog来分析MySQL数据库
2008-10-13 12:35:00
![](https://img.aspxhome.com/file/UploadPic/200810/13/SQLyog-mysql_15s.jpg)
Ubuntu下使用Python实现游戏制作中的切分图片功能
2021-02-22 22:55:53
![](https://img.aspxhome.com/file/2023/3/82783_0s.png)
asp fso type属性取得文件类型代码
2009-02-04 10:09:00
转换字符串单词的第一个字母为大写
2007-10-18 10:50:00
Access2K中的查询分析器
2008-11-20 16:40:00
关于多元线性回归分析——Python&SPSS
2023-03-11 17:03:34
![](https://img.aspxhome.com/file/2023/0/67230_0s.jpg)
Python如何输出警告信息
2022-01-25 23:34:44
网马解密大讲堂——网马解密中级篇(Freshow工具使用方法)
2009-09-16 15:09:00
![](https://img.aspxhome.com/file/UploadPic/20099/16/freshow-13s.jpg)
教你快速掌握怎样在Windows下升级MySQL
2008-12-31 17:08:00
Python 使用多属性来进行排序
2023-11-10 21:15:07
Go语言二进制文件的读写操作
2023-06-23 09:40:08
关于分页查询和性能问题
2008-03-11 12:25:00
一文详解go mod依赖管理详情
2023-07-13 04:35:06
![](https://img.aspxhome.com/file/2023/2/99062_0s.png)
oracle日期分组查询的完整实例
2023-06-26 10:14:13
![](https://img.aspxhome.com/file/2023/4/63374_0s.png)
详细讲解Access数据库远程连接的实用方法
2008-11-28 16:34:00
java正则表达式解析html示例分享
2023-06-13 15:53:42