pytorch中的embedding词向量的使用方法
作者:乐且有仪 时间:2022-03-25 09:05:27
Embedding
词嵌入在 pytorch 中非常简单,只需要调用 torch.nn.Embedding(m, n) 就可以了,m 表示单词的总数目,n 表示词嵌入的维度,其实词嵌入就相当于是一个大矩阵,矩阵的每一行表示一个单词。
emdedding初始化
默认是随机初始化的
import torch
from torch import nn
from torch.autograd import Variable
# 定义词嵌入
embeds = nn.Embedding(2, 5) # 2 个单词,维度 5
# 得到词嵌入矩阵,开始是随机初始化的
torch.manual_seed(1)
embeds.weight
# 输出结果:
Parameter containing:
-0.8923 -0.0583 -0.1955 -0.9656 0.4224
0.2673 -0.4212 -0.5107 -1.5727 -0.1232
[torch.FloatTensor of size 2x5]
如果从使用已经训练好的词向量,则采用
pretrained_weight = np.array(args.pretrained_weight) # 已有词向量的numpy
self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))
embed的读取
读取一个向量。
注意参数只能是LongTensor型的
# 访问第 50 个词的词向量
embeds = nn.Embedding(100, 10)
embeds(Variable(torch.LongTensor([50])))
# 输出:
Variable containing:
0.6353 1.0526 1.2452 -1.8745 -0.1069 0.1979 0.4298 -0.3652 -0.7078 0.2642
[torch.FloatTensor of size 1x10]
读取多个向量。
输入为两个维度(batch的大小,每个batch的单词个数),输出则在两个维度上加上词向量的大小。
Input: LongTensor (N, W), N = mini-batch, W = number of indices to extract per mini-batch
Output: (N, W, embedding_dim)
见代码
# an Embedding module containing 10 tensors of size 3
embedding = nn.Embedding(10, 3)
# 每批取两组,每组四个单词
input = Variable(torch.LongTensor([[1,2,4,5],[4,3,2,9]]))
a = embedding(input) # 输出2*4*3
a[0],a[1]
输出为:
(Variable containing:
-1.2603 0.4337 0.4181
0.4458 -0.1987 0.4971
-0.5783 1.3640 0.7588
0.4956 -0.2379 -0.7678
[torch.FloatTensor of size 4x3], Variable containing:
-0.5783 1.3640 0.7588
-0.5313 -0.3886 -0.6110
0.4458 -0.1987 0.4971
-1.3768 1.7323 0.4816
[torch.FloatTensor of size 4x3])
来源:https://blog.csdn.net/david0611/article/details/81090371
标签:pytorch,embedding,词向量
0
投稿
猜你喜欢
Pytorch转onnx、torchscript方式
2022-05-03 11:10:43
python转化excel数字日期为标准日期操作
2021-01-14 22:38:59
Django中信号signals的简单使用方法
2023-08-18 08:49:49
ASP实现表单中容量大的数据的提交方法
2008-10-16 11:07:00
Python高级用法总结
2021-04-20 13:03:01
一文教会你pandas plot各种绘图
2021-04-29 19:41:11
Python内置加密模块用法解析
2021-09-17 02:26:19
OpenCV实现直线检测
2023-08-14 01:37:35
Python+OpenCV实现阈值分割的方法详解
2023-08-13 02:24:00
10个Python面试常问的问题(小结)
2023-04-11 19:36:15
python实现基于两张图片生成圆角图标效果的方法
2023-04-20 17:58:56
git版本库介绍及本地创建的三种场景方式
2023-07-11 11:22:18
一文速学Python+Pyecharts绘制树形图
2023-07-28 12:05:27
Node.js基础模块babel使用详解
2024-05-13 09:35:11
学习CSS布局心得
2007-05-11 16:50:00
python实现八大排序算法(2)
2023-09-05 06:28:23
Golang实现HTTP编程请求和响应
2024-04-28 09:10:42
Django模板报TemplateDoesNotExist异常(亲测可行)
2023-11-02 18:53:49
Python机器学习pytorch交叉熵损失函数的深刻理解
2021-12-11 06:09:40
MySQL数据库之Purge死锁问题解析
2024-01-28 05:11:50