Beam search and how to implement it in PyTorch

Author: BruceWu1234  Posted: 2023-08-02 10:29:30

This post records two different versions of beam search.

Version 1

The search proceeds like a level-order traversal and is maintained with a queue. Each iteration expands every node in the current level; each of these nodes contributes its top-k most probable children as candidates for the next level, and the top k of all candidates become the next level's nodes and are pushed onto the queue.

This is BFS with a width constraint: a heuristic, greedy search. If k -> inf, it becomes equivalent to plain BFS.

Starting from the root node (the <sos> token), pick the k most probable tokens out of all possibilities (typically tens of thousands, i.e. the vocabulary size) and expand them as the next level's nodes.

Then, among all nodes these k nodes can expand into (usually k × tens of thousands), again pick the k with the highest probability (note that the probability here is cumulative, i.e. the product of probabilities along the path from the root to that node) and expand them. The k children chosen at this step may come from all k parents of the previous level, from only some of them, or even all from a single parent. Repeat in the same way.

The result is a tree with k nodes per level (except the root, the final level, and levels where fewer than k candidates are available).

In practice, probabilities are taken in log space so that the cumulative product does not underflow; path scores are then accumulated by summing log probabilities.

An example with k = 2:

From <sos>, pick the two tokens with the highest probability: "i": 0.6 and "he": 0.4. All other words are negligible.

There are 4 expansions in total: (1) after "i", suppose the most probable continuations are "love": 0.7 and "like": 0.3, everything else negligible; (2) after "he", suppose the most probable are "hates": 0.9 and "loves": 0.1, everything else negligible. Among these 4 candidates, "i love" has probability 0.6 * 0.7 = 0.42, "i like" has 0.6 * 0.3 = 0.18, "he hates" has 0.4 * 0.9 = 0.36, and "he loves" has 0.4 * 0.1 = 0.04. Keep the two with the highest probability: "i love" and "he hates".

The next level again has 4 expansions: (1) after "i love", suppose the most probable word is "you": 0.9, everything else summing to 0.1; (2) after "he hates", suppose the most probable are "her": 0.8 and "himself": 0.1, everything else summing to 0.1. Then "i love you" has probability 0.42 * 0.9 = 0.378 and "he hates her" has 0.36 * 0.8 = 0.288; the remaining candidates are all smaller and need not be computed. Again keep the two with the highest probability: "i love you" and "he hates her".

Expanding the next level, "i love you" should expand two children, and we find that <eos> has probability 0.99 with everything else summing to 0.01; "he hates her" should also expand two children, and again <eos> has probability 0.99 with everything else summing to 0.01. So the two most probable sequences are "i love you <eos>" and "he hates her <eos>". Since both branches have reached <eos>, the search terminates.

Finally, of the two finished sequences, pick the one with the higher probability: "i love you <eos>". Done.
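
A minimal standalone sketch of this worked example, accumulating log probabilities instead of multiplying raw probabilities. The toy next-token table below is made up to match the numbers above (filler tokens such as "them", "me", "so" stand in for "all other words"):

import math

k = 2
# toy next-token distributions, keyed only by the last token (an assumption for this sketch)
next_tokens = {
    "<sos>": {"i": 0.6, "he": 0.4},
    "i": {"love": 0.7, "like": 0.3},
    "he": {"hates": 0.9, "loves": 0.1},
    "love": {"you": 0.9, "them": 0.1},
    "hates": {"her": 0.8, "himself": 0.1, "me": 0.1},
    "you": {"<eos>": 0.99, "so": 0.01},
    "her": {"<eos>": 0.99, "so": 0.01},
}

beams = [(["<sos>"], 0.0)]  # (sequence, cumulative log probability)
while not all(seq[-1] == "<eos>" for seq, _ in beams):
    candidates = []
    for seq, score in beams:
        if seq[-1] == "<eos>":          # finished beams are carried over unchanged
            candidates.append((seq, score))
            continue
        for tok, p in next_tokens[seq[-1]].items():
            candidates.append((seq + [tok], score + math.log(p)))
    # keep the k candidates with the highest cumulative log probability
    beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]

for seq, score in beams:
    print(" ".join(seq), math.exp(score))
# expected: "<sos> i love you <eos>" ~ 0.374 and "<sos> he hates her <eos>" ~ 0.285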

The code below is excerpted from a project and keeps only the key parts; PyTorch implementation:


from queue import Queue

class Node(object):
    def __init__(self, hidden, previous_node, decoder_input, attn, log_prob, length):
        self.hidden = hidden
        self.previous_node = previous_node
        self.decoder_input = decoder_input
        self.attn = attn
        self.log_prob = log_prob
        self.length = length

def beam_search(beam_width):
    ...
    root = Node(hidden, None, decoder_input, None, 0, 1)
    q = Queue()
    q.put(root)

    end_nodes = []  # finished nodes, used later for backtracking
    while not q.empty():
        candidates = []  # candidate children for this level; each parent only contributes its own top-k children

        for _ in range(q.qsize()):
            node = q.get()
            decoder_input = node.decoder_input
            hidden = node.hidden

            # termination condition for this branch
            if decoder_input.item() == EOS or node.length >= 50:
                end_nodes.append(node)
                continue

            log_prob, hidden, attn = decoder(
                decoder_input, hidden, encoder_input
            )

            log_prob, indices = log_prob.topk(beam_width)  # top-k children of this parent

            for k in range(beam_width):
                index = indices[k].unsqueeze(0)
                log_p = log_prob[k].item()
                child = Node(hidden, node, index, attn, node.log_prob + log_p, node.length + 1)
                candidates.append((node.log_prob + log_p, child))  # candidate child; note the log probability is accumulated

        candidates = sorted(candidates, key=lambda x: x[0], reverse=True)  # sort candidates by cumulative log probability
        length = min(len(candidates), beam_width)  # keep the top k; if fewer than k, keep them all
        for i in range(length):
            q.put(candidates[i][1])
    # backtracking follows, omitted
    ...
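
The backtracking step omitted above simply follows previous_node pointers from each finished node back to the root. A minimal sketch of what it might look like (the helper name backtrack is not from the original excerpt):

def backtrack(end_nodes, topk=1):
    # sort finished hypotheses by cumulative log probability, best first
    end_nodes = sorted(end_nodes, key=lambda n: n.log_prob, reverse=True)
    results = []
    for node in end_nodes[:topk]:
        tokens = []
        while node is not None:
            tokens.append(node.decoder_input)
            node = node.previous_node
        results.append(tokens[::-1])  # reverse into root -> leaf order
    return results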

Version 2

Instead of traversing level by level, each step pops the node with the highest score from a priority queue, expands it, and pushes that node's top-k children back onto the priority queue. The loop terminates either when enough finished hypotheses (ending in EOS) have been collected or when the number of nodes pushed onto the queue exceeds a limit.


import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SOS_token = 0
EOS_token = 1
MAX_LENGTH = 50

class DecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, cell_type, dropout=0.1):
        '''
        Illustrative decoder
        '''
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.embedding = nn.Embedding(num_embeddings=output_size,
                                      embedding_dim=embedding_size,
                                      )
        # decoding is causal, so the GRU is unidirectional
        self.rnn = nn.GRU(embedding_size, hidden_size, batch_first=False)
        self.dropout_rate = dropout
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, not_used):
        embedded = self.embedding(input).transpose(0, 1)  # [B, 1] -> [1, B, D]
        embedded = F.dropout(embedded, self.dropout_rate)
        output = embedded
        # with batch_first=False the GRU output has shape (seq_len, batch_size, hidden_size) = [1, B, H]
        output, hidden = self.rnn(output, hidden)
        out = self.out(output.squeeze(0))
        # output: [batch_size, vocab_size]
        # hidden: [num_layers * num_directions, batch_size, hidden_size]
        output = F.log_softmax(out, dim=1)
        return output, hidden

class BeamSearchNode(object):
    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate: decoder hidden state at this node
        :param previousNode: parent node, used for backtracking
        :param wordId: token chosen at this step
        :param logProb: cumulative log probability of the partial hypothesis
        :param length: length of the partial hypothesis
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        reward = 0
        # Add here a function for shaping a reward
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward

    def __lt__(self, other):
        # tie-breaker so PriorityQueue can compare nodes with equal scores
        return self.leng < other.leng

decoder = DecoderRNN(embedding_size=256, hidden_size=512, output_size=30000, cell_type='gru')  # illustrative sizes; the original left the arguments out

def beam_decode(target_tensor, decoder_hiddens, encoder_outputs=None):
    '''
    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hiddens: input tensor of shape [1, B, H] for the start of decoding
    :param encoder_outputs: if you are using an attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of the input sentence
    :return: decoded_batch
    '''
    beam_width = 10
    topk = 1  # how many sentences do you want to generate
    decoded_batch = []
    # decoding goes sentence by sentence
    for idx in range(target_tensor.size(0)):
        if isinstance(decoder_hiddens, tuple):  # LSTM case
            decoder_hidden = (decoder_hiddens[0][:, idx, :].unsqueeze(0), decoder_hiddens[1][:, idx, :].unsqueeze(0))
        else:
            decoder_hidden = decoder_hiddens[:, idx, :].unsqueeze(0)
        encoder_output = encoder_outputs[:, idx, :].unsqueeze(1)
        # Start with the start-of-sentence token
        decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)
        # Number of sentences to generate
        endnodes = []
        number_required = min((topk + 1), topk - len(endnodes))
        # starting node - hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
        nodes = PriorityQueue()
        # start the queue
        nodes.put((-node.eval(), node))
        qsize = 1
        # start beam search
        while True:
            # give up when decoding takes too long
            if qsize > 2000:
                break
            # fetch the best node
            score, n = nodes.get()
            decoder_input = n.wordid
            decoder_hidden = n.h
            if n.wordid.item() == EOS_token and n.prevNode is not None:
                endnodes.append((score, n))
                # if we reached the maximum number of sentences required
                if len(endnodes) >= number_required:
                    break
                else:
                    continue
            # decode for one step using the decoder
            # decoder_output: [batch_size, vocab_size], decoder_hidden: [num_layers * num_directions, batch_size, hidden_size]
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)
            # PUT HERE REAL BEAM SEARCH OF TOP
            # log_prob, indexes: [batch_size, beam_width] = [1, beam_width]
            log_prob, indexes = torch.topk(decoder_output, beam_width, dim=1)
            nextnodes = []
            for new_k in range(beam_width):
                # decoded_t: view(1, -1) turns the scalar tensor into a [1, 1] tensor
                decoded_t = indexes[0][new_k].view(1, -1)
                # log_p: item() turns the scalar tensor into a Python float
                log_p = log_prob[0][new_k].item()
                node = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
                score = -node.eval()
                nextnodes.append((score, node))
            # put them into the queue
            for i in range(len(nextnodes)):
                score, next_node = nextnodes[i]
                nodes.put((score, next_node))
            # increase qsize: one node was popped, beam_width were pushed
            qsize += len(nextnodes) - 1
        # choose nbest paths, back trace them
        if len(endnodes) == 0:
            endnodes = [nodes.get() for _ in range(topk)]
        utterances = []
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            utterance = []
            utterance.append(n.wordid)
            # back trace
            while n.prevNode is not None:
                n = n.prevNode
                utterance.append(n.wordid)
            utterance = utterance[::-1]
            utterances.append(utterance)
        decoded_batch.append(utterances)
    return decoded_batch

def greedy_decode(decoder_hidden, encoder_outputs, target_tensor):
    '''
    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hidden: input tensor of shape [1, B, H] for the start of decoding
    :param encoder_outputs: if you are using an attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of the input sentence
    :return: decoded_batch
    '''
    batch_size, seq_len = target_tensor.size()
    decoded_batch = torch.zeros((batch_size, MAX_LENGTH))
    decoder_input = torch.tensor([[SOS_token] for _ in range(batch_size)], dtype=torch.long, device=device)
    for t in range(MAX_LENGTH):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
        topv, topi = decoder_output.data.topk(1)  # get candidates
        topi = topi.view(-1)
        decoded_batch[:, t] = topi
        decoder_input = topi.detach().view(-1, 1)
    return decoded_batch
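
A hypothetical usage sketch showing the shapes beam_decode expects. The sizes and dummy tensors below are assumptions for illustration, not from the original post; with the untrained decoder above the output tokens are of course meaningless:

# hypothetical setup: batch of 2 sentences, max target length 20, hidden size 512 (must match the decoder above)
decoder.to(device)
B, T, H = 2, 20, 512
target_tensor = torch.randint(0, 30000, (B, T), device=device)   # [B, T] target token ids
decoder_hiddens = torch.zeros(1, B, H, device=device)             # [1, B, H] initial decoder hidden state
encoder_outputs = torch.zeros(T, B, H, device=device)             # [T, B, H] (unused by this illustrative decoder)
decoded = beam_decode(target_tensor, decoder_hiddens, encoder_outputs)
print(len(decoded), len(decoded[0]))                               # 2 sentences, topk = 1 hypothesis each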

Supplement: a simple beam search example with explanation

Here is the code:


from math import log
from numpy import array
from numpy import argmax

# beam search
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    # walk over each step in the sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                # accumulate the score by multiplying with -log(prob); lower is better
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score (ascending)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # select the k best
        sequences = ordered[:k]
    return sequences

# greedy decoder: index of the largest probability in each row
def greedy_decoder(data):
    return [argmax(s) for s in data]

# define a sequence of 10 decoding steps over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode the sequence with beam width 3
result = beam_search_decoder(data, 3)
# print the result
for seq in result:
    print(seq)

The value of sequences after each of the first few iterations:

[[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]]

[[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]]

[[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]]

The final printed result:

[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108]

[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397]

[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]
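
For comparison, the greedy_decoder defined above simply takes the argmax at every step, which for this data gives the same token sequence as the top beam:

print(greedy_decoder(data))
# [4, 0, 4, 0, 4, 0, 4, 0, 4, 0]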

Source: https://blog.csdn.net/u014514939/article/details/95667422

Tags: beam, search, pytorch