PyTorch深度学习模型的保存和加载流程详解

作者:软耳朵DONG 时间:2023-07-10 04:58:33 

一、模型参数的保存和加载

  •  torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path所指定的文件存放路径(常用文件格式为.pt.pth.pkl)。

  • torch.nn.Module.load_state_dict(state_dict):从state_dict中加载参数和缓冲区到Module及其子类中 。

  • torch.nn.Module.state_dict()函数返回python中的一个OrderedDict类型字典对象,该对象将每一层与它的对应参数和缓冲区建立映射关系,字典的键值是参数或缓冲区的名称。只有那些参数可以训练的层才会被保存到OrderedDict中,例如:卷积层、线性层等。

  • Python中的字典类以“键:值”方式存取数据,OrderedDict是它的一个子类,实现了对字典对象中元素的排序(OrderedDict根据放入元素的先后顺序进行排序)。由于进行了排序,所以顺序不同的两个OrderedDict字典对象会被当做是两个不同的对象。

  • 示例:


import torch
import torch.nn as nn

class Net(nn.Module):
   def __init__(self):
       super(Net, self).__init__()
       self.conv1 = nn.Conv2d(1, 2, 3)
       self.pool1 = nn.MaxPool2d(2, 2)

def forward(self, x):
       x = self.conv1(x)
       x = self.pool1(x)
       return x

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 获取state_dict
state_dict = net.state_dict()
# 字典的遍历默认是遍历key,所以param_tensor实际上是键值
for param_tensor in state_dict:
   print(param_tensor,':\n',state_dict[param_tensor])
# 保存模型参数
torch.save(state_dict,"net_params.pth")
# 通过加载state_dict获取模型参数
net.load_state_dict(state_dict)

输出:

PyTorch深度学习模型的保存和加载流程详解

二、完整模型的保存和加载

  •  torch.save(module, path):将训练完的整个网络模型module保存到path所指定的文件存放路径(常用文件格式为.pt.pth)。

  • torch.load(path):加载保存到path中的整个神经网络模型。

  • 示例:


import torch
import torch.nn as nn

class Net(nn.Module):
   def __init__(self):
       super(Net, self).__init__()
       self.conv1 = nn.Conv2d(1, 2, 3)
       self.pool1 = nn.MaxPool2d(2, 2)

def forward(self, x):
       x = self.conv1(x)
       x = self.pool1(x)
       return x

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 保存整个网络
torch.save(net,"net.pth")
# 加载网络
net = torch.load("net.pth")

来源:https://blog.csdn.net/m0_52650517/article/details/120836999

标签:PyTorch,模型的保存,模型的加载
0
投稿

猜你喜欢

  • arcgis使用Python脚本进行批量截图功能实现

    2021-04-25 03:40:05
  • 不到20行实现Python代码即可制作精美证件照

    2021-08-29 09:27:43
  • Python对XML文件实现增删改查操作

    2023-11-19 20:42:03
  • MySQL最基本的命令使用汇总

    2024-01-28 06:18:30
  • html注释所引起的一系列问题

    2008-11-04 13:23:00
  • Python中Numpy和Matplotlib的基本使用指南

    2021-10-26 04:22:44
  • Django利用LogEntry生成历史操作实战记录

    2021-10-23 04:23:58
  • 深入理解Python虚拟机中字节(bytes)的实现原理及源码剖析

    2021-12-20 22:51:28
  • 使用Django和Postgres进行全文搜索的实例代码

    2022-07-06 10:52:15
  • 使用pytorch提取卷积神经网络的特征图可视化

    2023-02-01 20:32:30
  • Python编程语言的35个与众不同之处(语言特征和使用技巧)

    2023-11-21 23:09:25
  • JavaScript每天必学之事件

    2024-04-22 13:08:43
  • 相同记录行如何取最大值

    2008-07-26 12:32:00
  • Python实现同时兼容老版和新版Socket协议的一个简单WebSocket服务器

    2023-05-21 10:05:28
  • Django实现分页功能

    2023-04-04 11:00:56
  • Python使用smtplib模块发送电子邮件的流程详解

    2023-09-28 03:28:35
  • python简单实现旋转图片的方法

    2021-06-09 08:29:11
  • 微信小程序(订阅消息)功能

    2024-04-28 09:36:48
  • php面向对象程序设计介绍

    2023-05-25 05:31:11
  • mysql中数据库覆盖导入的几种方式总结

    2024-01-19 22:26:33
  • asp之家 网络编程 m.aspxhome.com