在pytorch中如何查看模型model参数parameters

作者:xiaoju233 时间:2021-12-04 22:43:29 

pytorch查看模型model参数parameters

示例1:pytorch自带的faster r-cnn模型

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

for name, p in model.named_parameters():
   print(name)
   print(p.requires_grad)
   print(...)

#或者

for p in model.parameters():
   print(p)
   print(...)

示例2:自定义网络模型

class Net(nn.Module):
   def __init__(self):
       super(Net, self).__init__()

cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
       self.features = self._vgg_layers(cfg)

def _vgg_layers(self, cfg):
       layers = []
       in_channels = 3
       for x in cfg:
           if x == 'M':
               layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
           else:
               layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1),
                       nn.BatchNorm2d(x),
                       nn.ReLU(inplace=True)
                       ]
               in_channels = x

return nn.Sequential(*layers)

def forward(self, data):
       out_map = self.features(data)
       return out_map

Model = Net()

for name, p in model.named_parameters():
   print(name)
   print(p.requires_grad)
   print(...)

#或者

for p in model.parameters():
   print(p)
   print(...)

在自定义网络中,model.parameters()方法继承自nn.Module

pytorch查看模型参数总结

1:DNN_printer

其中(3, 32, 32)是输入的大小,其他方法中的参数同理

from DNN_printer import DNN_printer

batch_size = 512
def train(epoch):
   print('\nEpoch: %d' % epoch)
   net.train()
   train_loss = 0
   correct = 0
   total = 0
   // put the code here and you can get the result
   DNN_printer(net, (3, 32, 32),batch_size)

结果

在pytorch中如何查看模型model参数parameters

2:parameters

def cnn_paras_count(net):
   """cnn参数量统计, 使用方式cnn_paras_count(net)"""
   # Find total parameters and trainable parameters
   total_params = sum(p.numel() for p in net.parameters())
   print(f'{total_params:,} total parameters.')
   total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
   print(f'{total_trainable_params:,} training parameters.')
   return total_params, total_trainable_params

cnn_paras_count(net)

直接输出参数量,然后自己计算

需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。

例如:

  • 44426个参数*4 / 1024 ≈ 174KB

3:get_model_complexity_info()

from ptflops import get_model_complexity_info
from torchvision import models

net = models.mobilenet_v2()
ops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
print_per_layer_stat=True, verbose=True)

在pytorch中如何查看模型model参数parameters

4:torchstat

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

输出

在pytorch中如何查看模型model参数parameters

来源:https://blog.csdn.net/qq_38600065/article/details/105552816

标签:pytorch,查看模型,model,parameters
0
投稿

猜你喜欢

  • ASP真正随机不重复查询代码

    2010-01-02 20:40:00
  • 原生JS实现左右箭头选择日期实例代码

    2023-08-06 04:55:27
  • pytorch中dataloader 的sampler 参数详解

    2023-09-16 21:00:13
  • Python 分支结构详解

    2021-03-17 01:43:06
  • php floor()函数案例详解

    2023-06-14 16:13:03
  • python使用PIL模块获取图片像素点的方法

    2022-07-28 10:57:57
  • PHP中SESSION使用中的一点经验总结

    2023-11-19 11:48:54
  • Dojo Style Javascript 编程规范

    2007-10-25 17:24:00
  • Python爬取几千条相亲文案

    2023-01-19 22:59:56
  • 详解Python3中yield生成器的用法

    2021-09-03 05:59:27
  • IPython 8.0 Python 命令行交互工具

    2022-10-24 09:17:54
  • python函数装饰器构造和参数传递

    2023-05-24 16:49:17
  • 教你用Python查看茅台股票交易数据的详细代码

    2022-06-05 13:36:16
  • 对Python中class和instance以及self的用法详解

    2022-09-08 23:28:14
  • 关于JSON以及JSON在PHP中的应用技巧

    2023-11-16 00:03:38
  • python常见的占位符总结及用法

    2023-10-11 10:39:58
  • python爬虫的一个常见简单js反爬详解

    2022-10-26 11:46:27
  • Python3 集合set入门基础

    2021-04-25 10:20:46
  • 详解用python实现基本的学生管理系统(文件存储版)(python3)

    2021-10-08 18:26:29
  • python中函数的返回值及类型详解

    2023-02-17 10:58:11
  • asp之家 网络编程 m.aspxhome.com