PyTorch预训练Bert模型的示例

作者:BLACK 时间:2021-11-12 14:31:39 

本文介绍以下内容:
1. 使用transformers框架做预训练的bert-base模型;
2. 开发平台使用Google的Colab平台,白嫖GPU加速;
3. 使用datasets模块下载IMDB影评数据作为训练数据。

transformers模块简介

transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型。使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用。

本来打算预训练bert-large模型,发现colab上GPU显存不够用,只能使用base版本了。打开colab,并且设置好GPU加速,接下来开始介绍代码。

代码实现

首先安装数据下载模块和transformers包。


pip install datasets
pip install transformers

使用datasets下载IMDB数据,返回DatasetDict类型的数据.返回的数据是文本类型,需要进行编码。下面会使用tokenizer进行编码。


from datasets import load_dataset

imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 打印前3条训练数据

接下来加载tokenizer和模型.从transformers导入AutoModelForSequenceClassification, AutoTokenizer,创建模型和tokenizer。


from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

对原始数据进行编码,并且分批次(batch)


def preprocessing_func(examples):
 return tokenizer(examples['text'],
          padding=True,
          truncation=True, max_length=300)

batch_size = 16

encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)

上面得到编码数据,每个批次设置为16.接下来需要指定训练的参数,训练参数的指定使用transformers给出的接口类TrainingArguments,模型的训练可以使用Trainer。


from transformers import Trainer, TrainingArguments

args = TrainingArguments(
 'out',
 per_device_train_batch_size=batch_size,
 per_device_eval_batch_size=batch_size,
 learning_rate=5e-5,
 evaluation_strategy='epoch',
 num_train_epochs=10,
 load_best_model_at_end=True,
)

trainer = Trainer(
 model,
 args=args,
 train_dataset=encoded_data['train'],
 eval_dataset=encoded_data['test'],
 tokenizer=tokenizer
)

训练模型使用trainer对象的train方法


trainer.train()

PyTorch预训练Bert模型的示例

评估模型使用trainer对象的evaluate方法


trainer.evaluate()

总结

本文介绍了基于transformers框架实现的bert预训练模型,此框架提供了非常友好的接口,可以方便读者尝试各种预训练模型。同时datasets也提供了很多数据集,便于学习NLP的各种问题。加上Google提供的colab环境,数据下载和预训练模型下载都非常快,建议读者自行去炼丹。本文完整的案例下载

来源:http://www.blackedu.vip/729/pytorch-yu-xun-lianbert-mo-xing/?utm_source=tuicool&utm_medium=referral

标签:PyTorch,训练,Bert,模型
0
投稿

猜你喜欢

  • 在python里创建一个任务(Task)实例

    2023-09-12 23:24:16
  • Golang拾遗之自定义类型和方法集详解

    2024-03-19 03:08:29
  • 深入理解JSON数据源格式

    2024-05-10 14:06:09
  • JDBC连接mysql处理中文时乱码解决办法详解

    2024-01-17 21:35:35
  • 一起感受HTML5和CSS3的能量[译]

    2009-09-04 16:29:00
  • 推荐五个常用的python图像处理库

    2022-07-20 10:40:38
  • vue-element换肤所有主题色和基础色均可实现自主配置

    2024-04-28 09:29:00
  • oracle误删数据表还原的二种方法(oracle还原)

    2024-01-14 21:33:55
  • python字典多键值及重复键值的使用方法(详解)

    2023-03-18 14:23:15
  • 更改Ubuntu默认python版本的两种方法python-> Anaconda

    2021-07-30 15:33:28
  • 彻底删除thinkphp3.1案例blog标签的方法

    2023-11-21 12:01:01
  • 格式化数字ASP,PHP版

    2009-01-19 14:17:00
  • SQL server 管理事务和数据库介绍

    2024-01-21 18:54:32
  • 如何从Notes中读取数据?

    2009-11-15 19:57:00
  • python实现键盘输入的实操方法

    2022-07-25 19:18:25
  • Python栈的实现方法示例【列表、单链表】

    2023-07-20 15:51:42
  • php设计模式之适配器模式实例分析【星际争霸游戏案例】

    2024-05-11 09:55:05
  • Java生成日期时间存入Mysql数据库的实现方法

    2024-01-13 03:49:08
  • Python内存管理方式和垃圾回收算法解析

    2022-09-10 17:49:11
  • python数据结构之面向对象

    2021-04-09 08:02:06
  • asp之家 网络编程 m.aspxhome.com