在pytorch中如何查看模型model参数parameters
作者:xiaoju233 时间:2021-12-04 22:43:29
pytorch查看模型model参数parameters
示例1:pytorch自带的faster r-cnn模型
import torch
import torchvision
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
示例2:自定义网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
self.features = self._vgg_layers(cfg)
def _vgg_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1),
nn.BatchNorm2d(x),
nn.ReLU(inplace=True)
]
in_channels = x
return nn.Sequential(*layers)
def forward(self, data):
out_map = self.features(data)
return out_map
Model = Net()
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
在自定义网络中,model.parameters()方法继承自nn.Module
pytorch查看模型参数总结
1:DNN_printer
其中(3, 32, 32)是输入的大小,其他方法中的参数同理
from DNN_printer import DNN_printer
batch_size = 512
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
// put the code here and you can get the result
DNN_printer(net, (3, 32, 32),batch_size)
结果
2:parameters
def cnn_paras_count(net):
"""cnn参数量统计, 使用方式cnn_paras_count(net)"""
# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in net.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
return total_params, total_trainable_params
cnn_paras_count(net)
直接输出参数量,然后自己计算
需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。
例如:
44426个参数*4 / 1024 ≈ 174KB
3:get_model_complexity_info()
from ptflops import get_model_complexity_info
from torchvision import models
net = models.mobilenet_v2()
ops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
print_per_layer_stat=True, verbose=True)
4:torchstat
from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))
输出
来源:https://blog.csdn.net/qq_38600065/article/details/105552816
标签:pytorch,查看模型,model,parameters
0
投稿
猜你喜欢
Python调用Zoomeye搜索接口的实现
2021-08-26 03:59:24
javascript函数定义的几种区别小结
2024-04-10 10:40:03
python加速器numba使用详解
2022-02-27 15:24:22
js动态显示当前日期,时间和星期代码
2007-08-14 12:31:00
PyQt5实现用户登录GUI界面及登录后跳转
2021-04-08 07:50:03
如何使用Python自动生成报表并以邮件发送
2021-07-20 08:51:06
Python request post上传文件常见要点
2022-11-05 09:27:14
深入探讨opencv图像矫正算法实战
2022-06-03 16:20:39
参数传递解决window.open的session变量丢失
2007-10-22 17:40:00
python中的Reportlab模块详解最新推荐
2023-04-09 21:33:46
mysql 5.7.13 解压缩版(免安装)安装配置教程
2024-01-24 01:13:28
sql处理数据库锁的存储过程分享
2023-07-05 18:03:25
解析:怎样掌握SQL Server中的数据查询
2009-01-19 13:30:00
Python3.7 新特性之dataclass装饰器
2021-05-11 13:13:40
Python实现简单的代理服务器
2023-03-28 15:13:50
python:关于文件加载及处理方式
2021-03-12 10:17:30
go常用指令之go mod详解
2024-04-23 09:49:09
举例讲解Python常用模块
2022-03-21 07:35:49
在MySQL中使用通配符时应该注意的问题
2024-01-26 13:17:07
Python解析并读取PDF文件内容的方法
2021-07-13 20:06:21