Python如何加载模型并查看网络

作者:ShuqiaoS 时间:2021-11-01 15:53:22 

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

Python如何加载模型并查看网络

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
   def __init__(self):
       super(test, self).__init__()
       self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

def forward(self, x):
       x = self.conv1(x)
       return x
model3 = torch.load("Test_method1.pth")

来源:https://blog.csdn.net/ShuqiaoS/article/details/90766761

标签:Python,加载模型,查看网络
0
投稿

猜你喜欢

  • 利用Pytorch实现获取特征图的方法详解

    2023-09-11 16:16:02
  • ASP 高级模板引擎实现类

    2011-03-25 10:54:00
  • 微软建议的ASP性能优化28条守则(1)

    2008-02-22 16:54:00
  • python保存图片时如何和原图大小一致

    2022-07-13 03:34:36
  • python给视频添加背景音乐并改变音量的具体方法

    2021-01-26 20:18:47
  • Python pygame 项目实战事件监听

    2023-05-31 21:33:20
  • asp 数据库连接函数代码

    2011-04-04 11:08:00
  • 常用原生js自定义函数总结

    2024-04-16 09:05:57
  • ORACLE 自动提交问题

    2023-07-24 10:43:13
  • Flash真的适合做网站应用吗?

    2011-04-16 10:34:00
  • tensorflow 动态获取 BatchSzie 的大小实例

    2023-03-05 16:56:48
  • 简单三步轻松实现ORACLE字段自增

    2024-01-16 06:06:58
  • js实现鼠标悬浮给图片加边框的方法

    2024-04-18 09:40:22
  • Python实现读取csv文件并进行排序

    2021-06-27 08:37:59
  • Vue指令v-for遍历输出JavaScript数组及json对象的常见方式小结

    2024-05-28 15:47:20
  • 模仿PHP写的ASP分页函数

    2008-04-13 06:11:00
  • Python通过RabbitMQ服务器实现交换机功能的实例教程

    2023-08-24 01:15:19
  • golang中strconv.ParseInt函数用法示例

    2024-04-23 09:46:48
  • 完美卸载Oracle数据库

    2024-01-18 02:50:27
  • 自学python求已知DNA模板的互补DNA序列

    2022-07-05 13:24:56
  • asp之家 网络编程 m.aspxhome.com