pytorch模型保存与加载中的一些问题实战记录

作者:colourmind 时间:2021-09-03 21:41:50 

前言

最近使用pytorch训练模型,保存模型后再次加载使用出现了一些问题。记录一下解决方案!

一、torch中模型保存和加载的方式

1、模型参数和模型结构保存和加载

torch.save(model,path)
torch.load(path)

2、只保存模型的参数和加载——这种方式比较安全,但是比较稍微麻烦一点点

torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)

二、torch中模型保存和加载出现的问题

1、单卡模型下保存模型结构和参数后加载出现的问题

模型保存的时候会把模型结构定义文件路径记录下来,加载的时候就会根据路径解析它然后装载参数;当把模型定义文件路径修改以后,使用torch.load(path)就会报错。

pytorch模型保存与加载中的一些问题实战记录

pytorch模型保存与加载中的一些问题实战记录

pytorch模型保存与加载中的一些问题实战记录

把model文件夹修改为models后,再加载就会报错。

import torch
from model.TextRNN import TextRNN

load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)

这种保存完整模型结构和参数的方式,一定不要改动模型定义文件路径

2、多卡机器单卡训练模型保存后在单卡机器上加载会报错

在多卡机器上有多张显卡0号开始,现在模型在n>=1上的显卡训练保存后,拷贝在单卡机器上加载

import torch
from model.TextRNN import TextRNN

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)

pytorch模型保存与加载中的一些问题实战记录

会出现cuda device不匹配的问题——你保存的模代码段 小部件型是使用的cuda1,那么采用torch.load()打开的时候,会默认的去寻找cuda1,然后把模型加载到该设备上。这个时候可以直接使用map_location来解决,把模型加载到CPU上即可。

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))

3、多卡训练模型保存模型结构和参数后加载出现的问题

当用多GPU同时训练模型之后,不管是采用模型结构和参数一起保存还是单独保存模型参数,然后在单卡下加载都会出现问题

a、模型结构和参数一起保然后在加载

pytorch模型保存与加载中的一些问题实战记录

torch.distributed.init_process_group(backend='nccl')

模型训练的时候采用上述多进程的方式,所以你在加载的时候也要声明,不然就会报错。

b、单独保存模型参数

model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)

同样会出现问题,不过这里出现的问题是参数字典的key和模型定义的key不一样

pytorch模型保存与加载中的一些问题实战记录

原因是多GPU训练下,使用分布式训练的时候会给模型进行一个包装,代码如下:

model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)

包装前的模型结构:

pytorch模型保存与加载中的一些问题实战记录

包装后的模型

pytorch模型保存与加载中的一些问题实战记录

在外层多了DistributedDataParallel以及module,所以才会导致在单卡环境下加载模型权重的时候出现权重的keys不一致。

三、正确的保存模型和加载的方法

if gpu_count > 1:
       torch.save(model.module.state_dict(),save_path)
   else:
       torch.save(model.state_dict(),save_path)
   model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
   state_dict = torch.load(save_path)
   model.load_state_dict(state_dict)

这样就是比较好的范式,加载不会出错。

来源:https://blog.csdn.net/HUSTHY/article/details/115199280

标签:pytorch,模型,加载
0
投稿

猜你喜欢

  • MySQL8.0.32的安装与配置超详细图文教程

    2024-01-17 11:24:42
  • 详解如何用SQLyog来分析MySQL数据库

    2008-10-13 12:35:00
  • Python学习之函数 def

    2022-09-06 09:57:41
  • 关于golang监听rabbitmq消息队列任务断线自动重连接的问题

    2024-04-25 13:21:03
  • sqlserver数据库最大Id冲突问题解决方法之一

    2024-01-28 01:48:06
  • Python进程间通讯与进程池超详细讲解

    2023-09-05 16:50:41
  • js打开新窗口方法整理

    2024-04-10 16:13:05
  • 《CSS权威指南》文摘(1)--块级元素、行内元素

    2008-04-05 13:42:00
  • 从Web查询数据库之PHP与MySQL篇

    2009-09-19 16:58:00
  • python处理文本文件并生成指定格式的文件

    2021-05-16 15:43:42
  • 教你使用Python画棵圣诞树完整代码

    2022-04-12 11:29:57
  • asp日期转换成汉字格式程序

    2008-07-08 18:19:00
  • python IDLE添加行号显示教程

    2022-03-30 18:55:52
  • 如何利用python将Xmind用例转为Excel用例

    2022-06-18 19:18:46
  • mysql8.0主从复制搭建与配置方案

    2024-01-15 11:26:25
  • X/HTML5 v.s. XHTML2(I)

    2008-06-17 18:00:00
  • python整小时 整天时间戳获取算法示例

    2021-02-11 10:27:33
  • python好玩的项目—色情图片识别代码分享

    2022-01-26 03:19:19
  • python中的集合及集合常用的使用方法

    2023-05-04 11:14:11
  • vue.js如何在网页中实现一个金属抛光质感的按钮

    2024-04-28 09:21:26
  • asp之家 网络编程 m.aspxhome.com