pytorch加载自定义网络权重的实现

作者:wuming无名 时间:2022-06-16 14:39:10 

在将自定义的网络权重加载到网络中时,报错:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

我们一步一步分析。

模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')

(1)查看获取模型权重的源码:

pytorch源码:net.state_dict()


def state_dict(self, destination=None, prefix='', keep_vars=False):
 r"""Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are
 included. Keys are corresponding parameter and buffer names.

Returns:
   dict:
     a dictionary containing a whole state of the module

Example::

>>> module.state_dict().keys()
   ['bias', 'weight']

"""

将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!

(2)查看保存模型权重的源码:

pytorch源码:torch.save()


def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
 """Saves an object to a disk file.

See also: :ref:`recommend-saving-models`

Args:
   obj: saved object
   f: a file-like object (has to implement write and flush) or a string
     containing a file name
   pickle_module: module used for pickling metadata and objects
   pickle_protocol: can be specified to override the default protocol

.. warning::
   If you are using Python 2, torch.save does NOT support StringIO.StringIO
   as a valid file-like object. This is because the write method should return
   the number of bytes written; StringIO.write() does not do this.

Please use something like io.BytesIO instead.

函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。

解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()


#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))

解决方法很简单,主要记录解决思路。

来源:https://blog.csdn.net/qq_34789262/article/details/83376374

标签:pytorch,加载,网络,权重
0
投稿

猜你喜欢

  • Python使用minidom读写xml的方法

    2022-03-14 11:35:22
  • 如何在scrapy中捕获并处理各种异常

    2023-04-10 06:56:23
  • Python操作sqlite3快速、安全插入数据(防注入)的实例

    2022-04-22 16:38:14
  • Python调用百度AI实现颜值评分功能

    2023-07-30 22:53:40
  • python如何实现质数求和

    2023-03-02 20:17:24
  • Python Flask微信小程序登录流程及登录api实现代码

    2022-03-21 14:33:47
  • python绘制简单彩虹图

    2022-09-06 04:55:14
  • php测试程序运行速度和页面执行速度的代码

    2023-06-14 07:49:18
  • ASP UTF-8编码下字符串截取和获取长度函数

    2011-03-30 10:52:00
  • Python 对象中的数据类型

    2022-01-25 00:58:35
  • python开发中两个list之间传值示例

    2022-06-07 03:44:54
  • 详细解读Python中的__init__()方法

    2023-03-25 17:10:27
  • Gradio机器学习模型快速部署工具quickstart前篇

    2023-07-01 15:07:51
  • 使用pandas的DataFrame的plot方法绘制图像的实例

    2023-07-02 08:33:52
  • Flask框架利用Echarts实现绘制图形

    2023-01-08 11:52:42
  • Pytorch dataloader在加载最后一个batch时卡死的解决

    2022-09-15 06:50:34
  • MySQL出现1067错误如何解决?

    2008-09-03 12:25:00
  • 关于Python中进度条的六个实用技巧分享

    2023-07-03 09:58:39
  • 关于多元线性回归分析——Python&SPSS

    2023-03-11 17:03:34
  • Python学习笔记之读取文件、OS模块、异常处理、with as语法示例

    2023-03-20 21:54:58
  • asp之家 网络编程 m.aspxhome.com