pytorch 状态字典:state_dict使用详解

作者:wzg2016 时间:2023-01-16 11:42:52 

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"


torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用


model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令


torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令


model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令


torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:


  # Model class must be defined somewhere

model = torch.load(PATH)

model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.


conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)


for param in list(model.pretrained.parameters()):
param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:


#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# define model
class TheModelClass(nn.Module):
def __init__(self):
 super(TheModelClass,self).__init__()
 self.conv1 = nn.Conv2d(3,6,5)
 self.pool = nn.MaxPool2d(2,2)
 self.conv2 = nn.Conv2d(6,16,5)
 self.fc1 = nn.Linear(16*5*5,120)
 self.fc2 = nn.Linear(120,84)
 self.fc3 = nn.Linear(84,10)

def forward(self,x):
 x = self.pool(F.relu(self.conv1(x)))
 x = self.pool(F.relu(self.conv2(x)))
 x = x.view(-1,16*5*5)
 x = F.relu(self.fc1(x))
 x = F.relu(self.fc2(x))
 x = self.fc3(x)
 return x

# initial model
model = TheModelClass()

#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor,'\t',model.state_dict()[param_tensor].size())

print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])

print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)

print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)

model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

来源:https://blog.csdn.net/Strive_For_Future/article/details/83240081

标签:pytorch,字典,state,dict
0
投稿

猜你喜欢

  • 用Python展示动态规则法用以解决重叠子问题的示例

    2023-02-09 02:20:36
  • 详解pyinstaller selenium python3 chrome打包问题

    2023-08-12 08:26:40
  • Jsp+Servlet实现文件上传下载 删除上传文件(三)

    2023-06-27 16:29:29
  • axios请求二次封装之避免重复发送请求

    2024-04-09 10:45:32
  • git rebase 成功之后撤销的操作方法

    2022-08-27 17:39:47
  • 使用Python的Django框架结合jQuery实现AJAX购物车页面

    2023-05-21 01:59:28
  • 浅谈function(函数)中的动态参数

    2023-08-11 10:23:59
  • mysql实现设置定时任务的方法分析

    2024-01-18 03:37:18
  • Python+PyQT5实现手绘图片生成器

    2022-03-11 11:57:21
  • python matplotlib如何给图中的点加标签

    2023-02-23 12:16:47
  • 微信小程序跳一跳游戏 python脚本跳一跳刷高分技巧

    2023-01-31 22:30:44
  • JS图片懒加载的优点及实现原理

    2024-04-18 09:45:34
  • TMDPHP 模板引擎使用教程

    2023-11-15 03:21:56
  • 如何用OpenCV -python3实现视频物体追踪

    2022-04-02 23:15:58
  • 通过遮罩层实现浮层DIV登录的js代码

    2024-06-24 00:08:58
  • VSCode的使用配置以及VSCode插件的安装教程详解

    2023-05-31 14:48:46
  • 如何保持Oracle数据库的优良性能

    2024-01-14 18:05:32
  • 浅谈MySQL安装starting the server失败的解决办法

    2024-01-25 06:37:22
  • 解决MybatisPlus SqlServer OFFSET 分页问题

    2024-01-12 16:26:24
  • 网页设计经验谈

    2007-10-30 13:11:00
  • asp之家 网络编程 m.aspxhome.com