pytorch加载自定义网络权重的实现
作者:wuming无名 时间:2022-06-16 14:39:10
在将自定义的网络权重加载到网络中时,报错:
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
我们一步一步分析。
模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')
(1)查看获取模型权重的源码:
pytorch源码:net.state_dict()
def state_dict(self, destination=None, prefix='', keep_vars=False):
r"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> module.state_dict().keys()
['bias', 'weight']
"""
将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!
(2)查看保存模型权重的源码:
pytorch源码:torch.save()
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
"""Saves an object to a disk file.
See also: :ref:`recommend-saving-models`
Args:
obj: saved object
f: a file-like object (has to implement write and flush) or a string
containing a file name
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
.. warning::
If you are using Python 2, torch.save does NOT support StringIO.StringIO
as a valid file-like object. This is because the write method should return
the number of bytes written; StringIO.write() does not do this.
Please use something like io.BytesIO instead.
函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。
解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()
#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))
解决方法很简单,主要记录解决思路。
来源:https://blog.csdn.net/qq_34789262/article/details/83376374
标签:pytorch,加载,网络,权重
0
投稿
猜你喜欢
Python 实现任意区域文字识别(OCR)操作
2021-04-23 03:52:37
Python 控制终端输出文字的实例
2021-10-23 21:30:56
举例讲解Python中metaclass元类的创建与使用
2023-12-11 23:06:57
javascript跨域刷新实现代码
2024-04-16 08:46:37
使用字符串建立查询能加快服务器的解析速度吗?
2010-07-14 21:03:00
用VB编写ActiveX DLL实现ASP编程
2008-10-21 21:28:00
python3+django2开发一个简单的人员管理系统过程详解
2022-06-01 08:04:01
Mysql数据库错误代码中文详细说明
2024-01-16 09:55:21
对python:threading.Thread类的使用方法详解
2022-01-24 04:19:28
编程活动中几个不良现象
2008-09-01 12:23:00
Windows服务器MySQL中文乱码的解决方法
2024-01-12 16:46:51
在Django的通用视图中处理Context的方法
2023-02-25 20:50:45
掀起抛弃IE6的高潮吧
2009-02-26 12:44:00
解决Python中由于logging模块误用导致的内存泄露
2021-08-24 08:04:46
Python的Flask框架与数据库连接的教程
2024-01-24 14:43:55
python numpy实现rolling滚动案例
2023-08-24 17:12:45
简单谈谈Python中的反转字符串问题
2022-02-24 11:55:07
MySQL中Replace语句用法实例详解
2024-01-15 03:26:28
Linux下python制作名片示例
2022-06-07 00:29:33
Python实现自动玩贪吃蛇程序
2021-07-27 19:01:47