pytorch_pretrained_bert如何将tensorflow模型转化为pytorch模型

作者:乐清sss 时间:2022-04-18 18:07:58 

pytorch_pretrained_bert将tensorflow模型转化为pytorch模型

BERT仓库里的模型是TensorFlow版本的,需要进行相应的转换才能在pytorch中使用

在Google BERT仓库里下载需要的模型,这里使用的是中文预训练模型(chinese_L-12_H-768_A_12)

pytorch_pretrained_bert如何将tensorflow模型转化为pytorch模型

下载chinese_L-12_H-768_A-12.zip后解压,里面有5个文件

chinese_L-12_H-768_A-12.zip后解压,里面有5个文件

bert_config.json

bert_model.ckpt.data-00000-of-00001

bert_model.ckpt.index

bert_model.ckpt.meta

vocab.txt

使用bert仓库里的convert_bert_original_tf_checkpoint_to_pytorch.py将此模型转化为pytorch版本的,这里我的文件夹位置为:D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12,替换为自己的即可

python convert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12\bert_model.ckpt --bert_config_file D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12\bert_config.json --pytorch_dump_path D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12\pytorch_model.bin

注:这里让我疑惑的是模型有5个文件,为什么转化的时候使用的是bert_model.ckpt,而且这个文件也不存在呀,是我对TensorFlow的模型不太熟悉,查阅资料之后将5个文件的作用说明如下:


$ tree chinese_L-12_H-768_A-12/
chinese_L-12_H-768_A-12/
├── bert_config.json                     <- 模型配置文件
├── bert_model.ckpt.data-00000-of-00001  <- 保存断点文件列表,可以用来迅速查找最近一次的断点文件
├── bert_model.ckpt.index                <- 为数据文件提供索引,存储的核心内容是以tensor name为键以BundleEntry为值的表格entries,BundleEntry主要内容是权值的类型、形状、偏移、校验和等信息。
├── bert_model.ckpt.meta                 <- 是MetaGraphDef序列化的二进制文件,保存了网络结构相关的数据,包括graph_def和saver_def等
└── vocab.txt                            <- 模型词汇表文件

0 directories, 5 files

在调用模型时使用chinese_L-12_H-768_A-12\bert_model.ckpt即可。

TensorFlow 读取ckpt文件中的tensor,将ckpt模型转为pytorch模型

想用MobileNet V1训练自己的数据,发现pytorch没有MobileNet V1的预训练权重,只好先下载TensorFlow的预训练权重,再转成pytorch模型。

读取ckpt中的Tensor名称以及Tensor值

TensorFlow的MobileNet V1预训练权重文件如下:

pytorch_pretrained_bert如何将tensorflow模型转化为pytorch模型

解压完文件后,发现没有.ckpt文件,文件名只需'./my_model/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt'这样写就行。

写一半发现Tensor名称好难对应起来。希望能给大家一个参考,也希望大家多多支持脚本之家

来源:https://blog.csdn.net/sunyueqinghit/article/details/103458365

标签:pytorch,pretrained,bert,tensorflow,pytorch
0
投稿

猜你喜欢

  • pyecharts调整图例与各板块的位置间距实例

    2023-05-15 20:05:40
  • 详解bootstrap导航栏.nav与.navbar区别

    2023-08-15 19:18:42
  • CSS Hack经验总结

    2008-05-01 13:13:00
  • Go语言文件开关及读写操作示例

    2023-08-05 19:47:27
  • 使用ACCESS做网络版程序的四种解决方案

    2009-01-14 16:22:00
  • Python如何调用JS文件中的函数

    2022-11-21 01:23:11
  • MySQL (root@%) does not exist的问题

    2011-03-16 15:31:00
  • Python通过TensorFLow进行线性模型训练原理与实现方法详解

    2022-11-10 16:17:27
  • 关于Internet Explorer 8

    2009-03-22 15:40:00
  • selenium+python自动化测试之环境搭建

    2022-05-15 13:51:32
  • Python 捕获代码中所有异常的方法

    2022-08-31 06:44:00
  • Ajax发明人:Ajax并不适合所有网站

    2008-01-30 12:20:00
  • python unicodedata模块用法

    2021-04-05 20:53:55
  • asp如何实现网上考试功能?

    2010-05-24 18:32:00
  • Python3读取文件常用方法实例分析

    2023-07-07 16:13:43
  • .NET中获取程序根目录的常用方法介绍

    2023-07-09 19:52:41
  • ASP 生成静态新闻列表

    2009-03-03 12:25:00
  • Flask使用Pyecharts在单个页面展示多个图表的方法

    2021-10-12 18:16:35
  • python opencv根据颜色进行目标检测的方法示例

    2021-09-29 03:53:41
  • Python制作数据导入导出工具

    2023-07-25 06:59:59
  • asp之家 网络编程 m.aspxhome.com