pytorch加载自定义网络权重的实现

作者:wuming无名 时间:2022-06-16 14:39:10 

在将自定义的网络权重加载到网络中时,报错:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

我们一步一步分析。

模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')

(1)查看获取模型权重的源码:

pytorch源码:net.state_dict()


def state_dict(self, destination=None, prefix='', keep_vars=False):
 r"""Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are
 included. Keys are corresponding parameter and buffer names.

Returns:
   dict:
     a dictionary containing a whole state of the module

Example::

>>> module.state_dict().keys()
   ['bias', 'weight']

"""

将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!

(2)查看保存模型权重的源码:

pytorch源码:torch.save()


def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
 """Saves an object to a disk file.

See also: :ref:`recommend-saving-models`

Args:
   obj: saved object
   f: a file-like object (has to implement write and flush) or a string
     containing a file name
   pickle_module: module used for pickling metadata and objects
   pickle_protocol: can be specified to override the default protocol

.. warning::
   If you are using Python 2, torch.save does NOT support StringIO.StringIO
   as a valid file-like object. This is because the write method should return
   the number of bytes written; StringIO.write() does not do this.

Please use something like io.BytesIO instead.

函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。

解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()


#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))

解决方法很简单,主要记录解决思路。

来源:https://blog.csdn.net/qq_34789262/article/details/83376374

标签:pytorch,加载,网络,权重
0
投稿

猜你喜欢

  • Python 实现任意区域文字识别(OCR)操作

    2021-04-23 03:52:37
  • Python 控制终端输出文字的实例

    2021-10-23 21:30:56
  • 举例讲解Python中metaclass元类的创建与使用

    2023-12-11 23:06:57
  • javascript跨域刷新实现代码

    2024-04-16 08:46:37
  • 使用字符串建立查询能加快服务器的解析速度吗?

    2010-07-14 21:03:00
  • 用VB编写ActiveX DLL实现ASP编程

    2008-10-21 21:28:00
  • python3+django2开发一个简单的人员管理系统过程详解

    2022-06-01 08:04:01
  • Mysql数据库错误代码中文详细说明

    2024-01-16 09:55:21
  • 对python:threading.Thread类的使用方法详解

    2022-01-24 04:19:28
  • 编程活动中几个不良现象

    2008-09-01 12:23:00
  • Windows服务器MySQL中文乱码的解决方法

    2024-01-12 16:46:51
  • 在Django的通用视图中处理Context的方法

    2023-02-25 20:50:45
  • 掀起抛弃IE6的高潮吧

    2009-02-26 12:44:00
  • 解决Python中由于logging模块误用导致的内存泄露

    2021-08-24 08:04:46
  • Python的Flask框架与数据库连接的教程

    2024-01-24 14:43:55
  • python numpy实现rolling滚动案例

    2023-08-24 17:12:45
  • 简单谈谈Python中的反转字符串问题

    2022-02-24 11:55:07
  • MySQL中Replace语句用法实例详解

    2024-01-15 03:26:28
  • Linux下python制作名片示例

    2022-06-07 00:29:33
  • Python实现自动玩贪吃蛇程序

    2021-07-27 19:01:47
  • asp之家 网络编程 m.aspxhome.com