pytorch中的embedding词向量的使用方法

作者:乐且有仪 时间:2022-03-25 09:05:27 

Embedding

词嵌入在 pytorch 中非常简单,只需要调用 torch.nn.Embedding(m, n) 就可以了,m 表示单词的总数目,n 表示词嵌入的维度,其实词嵌入就相当于是一个大矩阵,矩阵的每一行表示一个单词。

emdedding初始化

默认是随机初始化的


import torch
from torch import nn
from torch.autograd import Variable
# 定义词嵌入
embeds = nn.Embedding(2, 5) # 2 个单词,维度 5
# 得到词嵌入矩阵,开始是随机初始化的
torch.manual_seed(1)
embeds.weight
# 输出结果:
Parameter containing:
-0.8923 -0.0583 -0.1955 -0.9656 0.4224
0.2673 -0.4212 -0.5107 -1.5727 -0.1232
[torch.FloatTensor of size 2x5]

如果从使用已经训练好的词向量,则采用


pretrained_weight = np.array(args.pretrained_weight) # 已有词向量的numpy
self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))

embed的读取

读取一个向量。

注意参数只能是LongTensor型的


# 访问第 50 个词的词向量
embeds = nn.Embedding(100, 10)
embeds(Variable(torch.LongTensor([50])))
# 输出:
Variable containing:
0.6353 1.0526 1.2452 -1.8745 -0.1069 0.1979 0.4298 -0.3652 -0.7078 0.2642
[torch.FloatTensor of size 1x10]

读取多个向量。

输入为两个维度(batch的大小,每个batch的单词个数),输出则在两个维度上加上词向量的大小。


Input: LongTensor (N, W), N = mini-batch, W = number of indices to extract per mini-batch
Output: (N, W, embedding_dim)

见代码


# an Embedding module containing 10 tensors of size 3
embedding = nn.Embedding(10, 3)
# 每批取两组,每组四个单词
input = Variable(torch.LongTensor([[1,2,4,5],[4,3,2,9]]))
a = embedding(input) # 输出2*4*3
a[0],a[1]

输出为:


(Variable containing:
-1.2603 0.4337 0.4181
0.4458 -0.1987 0.4971
-0.5783 1.3640 0.7588
0.4956 -0.2379 -0.7678
[torch.FloatTensor of size 4x3], Variable containing:
-0.5783 1.3640 0.7588
-0.5313 -0.3886 -0.6110
0.4458 -0.1987 0.4971
-1.3768 1.7323 0.4816
[torch.FloatTensor of size 4x3])

来源:https://blog.csdn.net/david0611/article/details/81090371

标签:pytorch,embedding,词向量
0
投稿

猜你喜欢

  • 解读HTML:命名空间与字符编码

    2008-12-10 14:03:00
  • bootstrap响应式工具使用详解

    2023-08-07 18:14:35
  • Iinternet Explorer浏览器简介(IE)

    2009-02-05 20:59:00
  • 浅谈ASP自动采集程序及入库

    2007-08-17 11:25:00
  • 将Reporting services的RDL文件拷贝到另外一台机器时报Data at the root level is invalid的解决方法

    2012-07-11 15:33:45
  • javascript封装的下拉导航菜单渐显效果

    2007-08-04 20:11:00
  • JS事件在IE与FF中的区别详细解析

    2023-09-24 23:02:35
  • JS载入数据效果!loading

    2009-01-20 18:35:00
  • ASP设计常见问题及解答精要

    2009-04-21 11:16:00
  • 利用WSH获取计算机硬件信息、DNS信息等

    2008-05-05 13:04:00
  • Python线性网络实现分类糖尿病病例

    2022-03-13 11:23:25
  • SQL Server 2000 作数据库服务器的优点

    2009-01-23 13:47:00
  • 编码问题引起的折腾

    2009-07-03 12:43:00
  • 静态页面利用JS读取cookies记住用户信息

    2011-04-14 11:17:00
  • Python使用Matplotlib绘制三维散点图详解流程

    2023-09-17 13:36:59
  • 关于浏览器的一些观点

    2008-08-06 12:48:00
  • python2 与 python3 实现共存的方法

    2023-06-13 23:56:29
  • SQL Server 2005 输入框不能输入中文问题

    2010-02-04 09:14:00
  • 如何实现文本的卷屏浏览?

    2010-05-24 18:36:00
  • MySQL数据库配置技巧

    2009-03-06 14:32:00
  • asp之家 网络编程 m.aspxhome.com