解决pytorch 的state_dict()拷贝问题

作者:Luke_Ye 时间:2022-10-05 22:03:57 

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了


import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
 def __init__(self):
   super(TheModelClass, self).__init__()
   self.conv1 = nn.Conv2d(3, 6, 5)
   self.pool = nn.MaxPool2d(2, 2)
   self.conv2 = nn.Conv2d(6, 16, 5)
   self.fc1 = nn.Linear(16 * 5 * 5, 120)
   self.fc2 = nn.Linear(120, 84)
   self.fc3 = nn.Linear(84, 10)
 def forward(self, x):
   x = self.pool(F.relu(self.conv1(x)))
   x = self.pool(F.relu(self.conv2(x)))
   x = x.view(-1, 16 * 5 * 5)
   x = F.relu(self.fc1(x))
   x = F.relu(self.fc2(x))
   x = self.fc3(x)
   return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
 print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:


Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:


net=BaseNet()

保存net时报错 object has no attribute 'state_dict'


torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:


class BaseNet(object):
 def __init__(self):

把类定义改为


class BaseNet(nn.Module):
 def __init__(self):
   super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

来源:https://www.cnblogs.com/LukeStepByStep/p/11248361.html

标签:pytorch,state,dict,拷贝
0
投稿

猜你喜欢

  • 基于OpenCV目标跟踪实现人员计数器

    2022-11-17 15:04:03
  • 如何运用python读写CSV文件

    2021-11-13 04:35:36
  • golang的协程上下文的具体使用

    2024-02-01 00:41:02
  • PHP四舍五入精确小数位及取整

    2024-05-21 10:20:36
  • Python中的fileinput模块的简单实用示例

    2023-06-19 01:09:27
  • 基于PyQt5制作一个群发邮件工具

    2022-09-04 01:46:46
  • Golang range slice 与range array 之间的区别

    2024-05-21 10:26:31
  • Django模板之基本的 for 循环 和 List内容的显示方式

    2021-09-24 05:18:24
  • python开启摄像头以及深度学习实现目标检测方法

    2023-10-27 03:23:18
  • Windows下python3安装tkinter的问题及解决方法

    2023-03-30 11:45:31
  • Python中的chr()函数与ord()函数解析

    2021-10-21 13:19:26
  • 如何用ASP获知机器的网络配置?

    2010-06-11 19:58:00
  • python批量提交沙箱问题实例

    2023-12-14 07:40:09
  • 用Python在Excel里画出蒙娜丽莎的方法示例

    2023-12-18 02:59:21
  • python设置中文界面实例方法

    2023-08-30 18:56:30
  • WML初级教程之从实际应用中了解WML

    2008-09-04 11:24:00
  • Pytest+Request+Allure+Jenkins实现接口自动化

    2021-04-09 13:50:44
  • python基于三阶贝塞尔曲线的数据平滑算法

    2022-04-19 18:23:06
  • python client使用http post 到server端的代码

    2021-09-03 14:33:54
  • 抛砖:如何进行互联网项目开发

    2010-01-25 12:25:00
  • asp之家 网络编程 m.aspxhome.com