在pytorch中如何查看模型model参数parameters
作者:xiaoju233 时间:2021-12-04 22:43:29
pytorch查看模型model参数parameters
示例1:pytorch自带的faster r-cnn模型
import torch
import torchvision
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
示例2:自定义网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
self.features = self._vgg_layers(cfg)
def _vgg_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1),
nn.BatchNorm2d(x),
nn.ReLU(inplace=True)
]
in_channels = x
return nn.Sequential(*layers)
def forward(self, data):
out_map = self.features(data)
return out_map
Model = Net()
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
在自定义网络中,model.parameters()方法继承自nn.Module
pytorch查看模型参数总结
1:DNN_printer
其中(3, 32, 32)是输入的大小,其他方法中的参数同理
from DNN_printer import DNN_printer
batch_size = 512
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
// put the code here and you can get the result
DNN_printer(net, (3, 32, 32),batch_size)
结果
2:parameters
def cnn_paras_count(net):
"""cnn参数量统计, 使用方式cnn_paras_count(net)"""
# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in net.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
return total_params, total_trainable_params
cnn_paras_count(net)
直接输出参数量,然后自己计算
需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。
例如:
44426个参数*4 / 1024 ≈ 174KB
3:get_model_complexity_info()
from ptflops import get_model_complexity_info
from torchvision import models
net = models.mobilenet_v2()
ops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
print_per_layer_stat=True, verbose=True)
4:torchstat
from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))
输出
来源:https://blog.csdn.net/qq_38600065/article/details/105552816
标签:pytorch,查看模型,model,parameters
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
ASP真正随机不重复查询代码
2010-01-02 20:40:00
原生JS实现左右箭头选择日期实例代码
2023-08-06 04:55:27
![](https://img.aspxhome.com/file/2023/8/55938_0s.png)
pytorch中dataloader 的sampler 参数详解
2023-09-16 21:00:13
Python 分支结构详解
2021-03-17 01:43:06
![](https://img.aspxhome.com/file/2023/4/72084_0s.png)
php floor()函数案例详解
2023-06-14 16:13:03
python使用PIL模块获取图片像素点的方法
2022-07-28 10:57:57
PHP中SESSION使用中的一点经验总结
2023-11-19 11:48:54
![](https://img.aspxhome.com/file/2023/8/98088_0s.jpg)
Dojo Style Javascript 编程规范
2007-10-25 17:24:00
Python爬取几千条相亲文案
2023-01-19 22:59:56
![](https://img.aspxhome.com/file/2023/0/118050_0s.png)
详解Python3中yield生成器的用法
2021-09-03 05:59:27
IPython 8.0 Python 命令行交互工具
2022-10-24 09:17:54
![](https://img.aspxhome.com/file/2023/0/87720_0s.png)
python函数装饰器构造和参数传递
2023-05-24 16:49:17
教你用Python查看茅台股票交易数据的详细代码
2022-06-05 13:36:16
![](https://img.aspxhome.com/file/2023/3/107093_0s.jpg)
对Python中class和instance以及self的用法详解
2022-09-08 23:28:14
关于JSON以及JSON在PHP中的应用技巧
2023-11-16 00:03:38
python常见的占位符总结及用法
2023-10-11 10:39:58
python爬虫的一个常见简单js反爬详解
2022-10-26 11:46:27
Python3 集合set入门基础
2021-04-25 10:20:46
![](https://img.aspxhome.com/file/2023/2/105342_0s.png)
详解用python实现基本的学生管理系统(文件存储版)(python3)
2021-10-08 18:26:29
python中函数的返回值及类型详解
2023-02-17 10:58:11
![](https://img.aspxhome.com/file/2023/7/91897_0s.jpg)