PyTorch预训练Bert模型的示例
作者:BLACK 时间:2021-11-12 14:31:39
本文介绍以下内容:
1. 使用transformers框架做预训练的bert-base模型;
2. 开发平台使用Google的Colab平台,白嫖GPU加速;
3. 使用datasets模块下载IMDB影评数据作为训练数据。
transformers模块简介
transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型。使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用。
本来打算预训练bert-large模型,发现colab上GPU显存不够用,只能使用base版本了。打开colab,并且设置好GPU加速,接下来开始介绍代码。
代码实现
首先安装数据下载模块和transformers包。
pip install datasets
pip install transformers
使用datasets下载IMDB数据,返回DatasetDict类型的数据.返回的数据是文本类型,需要进行编码。下面会使用tokenizer进行编码。
from datasets import load_dataset
imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 打印前3条训练数据
接下来加载tokenizer和模型.从transformers导入AutoModelForSequenceClassification, AutoTokenizer,创建模型和tokenizer。
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)
对原始数据进行编码,并且分批次(batch)
def preprocessing_func(examples):
return tokenizer(examples['text'],
padding=True,
truncation=True, max_length=300)
batch_size = 16
encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)
上面得到编码数据,每个批次设置为16.接下来需要指定训练的参数,训练参数的指定使用transformers给出的接口类TrainingArguments,模型的训练可以使用Trainer。
from transformers import Trainer, TrainingArguments
args = TrainingArguments(
'out',
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
learning_rate=5e-5,
evaluation_strategy='epoch',
num_train_epochs=10,
load_best_model_at_end=True,
)
trainer = Trainer(
model,
args=args,
train_dataset=encoded_data['train'],
eval_dataset=encoded_data['test'],
tokenizer=tokenizer
)
训练模型使用trainer对象的train方法
trainer.train()
评估模型使用trainer对象的evaluate方法
trainer.evaluate()
总结
本文介绍了基于transformers框架实现的bert预训练模型,此框架提供了非常友好的接口,可以方便读者尝试各种预训练模型。同时datasets也提供了很多数据集,便于学习NLP的各种问题。加上Google提供的colab环境,数据下载和预训练模型下载都非常快,建议读者自行去炼丹。本文完整的案例下载
来源:http://www.blackedu.vip/729/pytorch-yu-xun-lianbert-mo-xing/?utm_source=tuicool&utm_medium=referral
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
Python ftp上传文件
Python常用标准库详解(pickle序列化和JSON序列化)
Ajax+asp应用实例 注册模块,表单提交
asp 隐藏并修改文件的最后修改时间
![](https://img.aspxhome.com/file/UploadPic/20112/5/201125112319598.jpg)
Design IT. (3),看不懂数据
Matplotlib 3D 绘制小红花原理
![](https://img.aspxhome.com/file/2023/1/88001_0s.png)
设计模式学习笔记之 - 简单工厂模式
Python flask框架端口失效解决方案
![](https://img.aspxhome.com/file/2023/5/89275_0s.png)
python用opencv将标注提取画框到对应的图像中
![](https://img.aspxhome.com/file/2023/5/72845_0s.png)
Python语法快速入门指南
![](https://img.aspxhome.com/file/2023/0/72140_0s.jpg)
[翻译]标记语言和样式手册 Chapter 13 为文字指定样式
![](https://img.aspxhome.com/file/UploadPic/20082/15/2008215162639458s.jpg)
MenuEverywhere 程序图标设计
Python实现微博动态图片爬取详解
![](https://img.aspxhome.com/file/2023/0/61730_0s.jpg)
90行Python代码开发个人云盘应用
![](https://img.aspxhome.com/file/2023/3/68363_0s.gif)
Python编写memcached启动脚本代码实例
yolov5返回坐标的方法实例
![](https://img.aspxhome.com/file/2023/3/76723_0s.jpg)
python+splinter实现12306网站刷票并自动购票流程
页面制作的重要性
![](https://img.aspxhome.com/file/UploadPic/200710/30/2007103013184371s.gif)