Beam Search快速理解及代码解析
阅读原文时间:2021年11月02日阅读:2

目录

Beam Search快速理解及代码解析(上)

简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索)。

生成式任务相比普通的分类、tagging等NLP任务会复杂不少。在生成的时候,模型的输出是一个时间步一个时间步依次获得的,而且前面时间步的结果还会影响后面时间步的结果。也就是说,每一个时间步,模型给出的都是基于历史生成结果的条件概率。为了生成完整的句子,需要一个称为解码的额外动作来融合模型多个时间步的输出,而且使得最终得到的序列的每一步条件概率连乘起来最大。

在文本生成任务中,每一个时间步可能的输出种类称为字典大小(vocabulary size,我们用V表示),进行T步随机的生成可能获得的结果总共有V^T种。拿中文文本生成来说,V 的值大约是5000-6000,即常用汉字的个数。在如此大的基数下,遍历整个生成空间是不现实的。

贪心搜索

每一个时间步都取出一个条件概率最大的输出,如图:

思路也很简单,就是稍微放宽一些考察的范围。在每一个时间步,不再只保留当前分数最高的1个输出,而是保留num_beams个。当num_beams=1时集束搜索就退化成了贪心搜索。

Beam Search示意图

  • 在第一个时间步,A和C是最优的两个,因此得到了两个结果[A],[C],其他三个就被抛弃了;
  • 第二步会基于这两个结果继续进行生成,在A这个分支可以得到5个候选人,[AA],[AB],[AC],[AD],[AE],C也同理得到5个,此时会对这10个进行统一排名,再保留最优的两个,即图中的[AB]和[CE];
  • 第三步同理,也会从新的10个候选人里再保留最好的两个,最后得到了[ABD],[CED]两个结果。 可以发现,beam search在每一步需要考察的候选人数量是贪心搜索的num_beams倍,因此是一种牺牲时间换性能的方法。

Beam Search的原理虽然简单,但实际实现的时候却有很多细节要考虑。下面要解析这个实现出自于NLP界著名Python包Transformers[1],我为了说明方便做了一些改动。

一个正确且高效的算法需要处理的问题大概有两个:

  • 充分利用硬件,可以处理批量数据,且尽量使用并行计算少用循环
  • 处理好长短不同的生成结果

下面是基础版的beam search函数定义。其中context是编码器编码获得的向量,batch_size是每批数据中包含的样本量,bos_token_id是句子开头标志的token id,pad_token_id是用于填充的token id,eos_token_id是句子结束标志的token id。这里给参数填上的默认值和我们后面讲解时使用的例子是一致的。

def beam_search_generate(context,
                        batch_size=3,
                        max_length=20,
                        min_length=2,
                        num_beams=2,
                        bos_token_id=101,
                        pad_token_id=0,
                        eos_token_id=102,
                        ):
    pass

在函数中主要执行以下三个步骤:

  • 准备初始输入
  • 在当前生成的序列长度未达到max_length时扩展生成序列
  • 准备最终输出的序列

准备初始输入

# 建立beam容器,每个样本一个
generated_hyps = [
    BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
    for _ in range(batch_size)
]

# 每个beam容器的得分,共batch_size*num_beams个
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=encoder_input_ids.device)
beam_scores = beam_scores.view(-1)

# 每个样本是否完成生成,共batch_size个
done = [False for _ in range(batch_size)]

# 为了并行计算,一次生成batch_size*num_beams个序列
# 第一步自动填入bos_token
input_ids = torch.full(
    (batch_size*num_beams, 1),
    bos_token_id,
    dtype=torch.long,
    device=next(self.parameters()).device,
)

# 当前长度设为1
cur_len = 1

其中BeamHypotheses是一个容器类,每个样本绑定一个。每个容器中会维护num_beams个当前最优的序列。当往容器中添加一个序列而导致序列数大于num_beams的时候,它会自动踢掉分数最低的那个序列。类代码如下。

class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty):
        self.max_length = max_length - 1   # ignoring bos_token
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            # 可更新的情况:数量未饱和或超过最差得分
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                # 数量饱和需要删掉一个最差的
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len=None):
        """
        相关样本是否已经完成生成。
        best_sum_logprobs是新的候选序列中的最高得分。
        """

        if len(self) < self.num_beams:
            return False
        else:
            if cur_len is None:
                cur_len = self.max_length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            # 是否最高分比当前保存的最低分还差
            ret = self.worst_score >= cur_score
            return ret

序列扩展

序列扩展是beam search的核心过程,我们特地画了一张图来解释这个版本的实现策略。

序列扩展示意图,下面对照这个图来讲解代码。

while cur_len < max_length:
    # 将编码器得到的上下文向量和当前结果输入解码器,即图中1
    output = decoder.decode_next_step(context, input_ids)
    # 输出矩阵维度为:(batch*num_beams)*cur_len*vocab_size

    # 取出最后一个时间步的各token概率,即当前条件概率
    # (batch*num_beams)*vocab_size
    scores = next_token_logits = output[:, -1, :]

    ###########################
    # 这里可以做一大堆操作减少重复 #
    ###########################

    # 计算序列条件概率的,因为取了log,所以直接相加即可。得到图中2矩阵
    # (batch_size * num_beams, vocab_size)
    next_scores = scores + beam_scores[:, None].expand_as(scores)

    # 为了提速,将结果重排成图中3的形状
    next_scores = next_scores.view(
            batch_size, num_beams * vocab_size
        )  # (batch_size, num_beams * vocab_size)

    # 取出分数最高的token(图中黑点)和其对应得分
    # sorted=True,保证返回序列是有序的
    next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

    # 下一个时间步整个batch的beam列表
    # 列表中的每一个元素都是三元组
    # (分数, token_id, beam_id)
    next_batch_beam = []

    # 对每一个样本进行扩展
    for batch_idx in range(batch_size):

        # 检查样本是否已经生成结束
        if done[batch_idx]:
            # 对于已经结束的句子,待添加的是pad token
            next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
            continue

        # 当前样本下一个时间步的beam列表
        next_sent_beam = []

        # 对于还未结束的样本需要找到分数最高的num_beams个扩展
        # 注意,next_scores和next_tokens是对应的
        # 而且已经按照next_scores排好顺序
        for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
            zip(next_tokens[batch_idx], next_scores[batch_idx])
        ):
            # get beam and word IDs
            # 这两行可参考图中3进行理解
            beam_id = beam_token_id // vocab_size
            token_id = beam_token_id % vocab_size

            effective_beam_id = batch_idx * num_beams + beam_id

            # 如果出现了EOS token说明已经生成了完整句子
            if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                # if beam_token does not belong to top num_beams tokens, it should not be added
                is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                if is_beam_token_worse_than_top_num_beams:
                    continue
                # 往容器中添加这个序列
                generated_hyps[batch_idx].add(
                    input_ids[effective_beam_id].clone(), beam_token_score.item(),
                )
            else:
                # add next predicted word if it is not eos_token
                next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

            # 扩展num_beams个就够了
            if len(next_sent_beam) == num_beams:
                break

        # 检查这个样本是否已经生成完了,有两种情况
        # 1. 已经记录过该样本结束
        # 2. 新的结果没有使结果改善
        done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
            next_scores[batch_idx].max().item(), cur_len=cur_len
        )

        # 把当前样本的结果添加到batch结果的后面
        next_batch_beam.extend(next_sent_beam)

    # 如果全部样本都已经生成结束便可以直接退出了
    if all(done):
        break

    # 把三元组列表再还原成三个独立列表
    beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
    beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
    beam_idx = input_ids.new([x[2] for x in next_batch_beam])

    # 准备下一时刻的解码器输入
    # 取出实际被扩展的beam
    input_ids = input_ids[beam_idx, :]
    # 在这些beam后面接上新生成的token
    input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)

    # 更新当前长度
    cur_len = cur_len + 1
    # end of length while

准备输出

上面那个while循环跳出意味着已经生成了长度为max_length的文本,比较理想的情况是所有的句子都已经生成出了eos_token_id,即句子生成结束了。但并不是所有情况都这样,对于那些”意犹未尽“的样本,我们需要先手动结束。

# 将未结束的生成结果结束,并置入容器中
for batch_idx in range(batch_size):
    # 已经结束的样本不需处理
    if done[batch_idx]:
        continue

    # 把结果加入到generated_hyps容器
    for beam_id in range(num_beams):
        effective_beam_id = batch_idx * num_beams + beam_id
        final_score = beam_scores[effective_beam_id].item()
        final_tokens = input_ids[effective_beam_id]
        generated_hyps[batch_idx].add(final_tokens,final_score)

经过上面的处理,所有生成好的句子都已经保存在generated_hyps容器中,每个容器内保存着num_beams个序列,最后就是输出期望个数的句子。

# select the best hypotheses,最终输出
# 每个样本返回几个句子
output_num_return_sequences_per_batch = 1
# 记录每个返回句子的长度,用于后面pad
sent_lengths = input_ids.new(output_batch_size)
best = []

# 对每个样本取出最好的output_num_return_sequences_per_batch个句子
for i, hypotheses in enumerate(generated_hyps):
    sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
    for j in range(output_num_return_sequences_per_batch):
        effective_batch_idx = output_num_return_sequences_per_batch * i + j
        best_hyp = sorted_hyps.pop()[1]
        sent_lengths[effective_batch_idx] = len(best_hyp)
        best.append(best_hyp)

# 如果长短不一则pad句子,使得最后返回结果的长度一样
if sent_lengths.min().item() != sent_lengths.max().item():
    sent_max_len = min(sent_lengths.max().item() + 1, max_length)
    # 先把输出矩阵填满PAD token
    decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

    # 填入真正的内容
    for i, hypo in enumerate(best):
        decoded[i, : sent_lengths[i]] = hypo
        # 填上eos token
        if sent_lengths[i] < max_length:
            decoded[i, sent_lengths[i]] = eos_token_id
else:
    # 所有生成序列都还没结束,直接堆叠即可
    decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

# 返回的结果包含BOS token
return decoded

好了,上面就是最基础的beam search算法。这样生成出来的结果已经会比贪心搜索好一些,但还是会遇到诸如词语重复这样的问题。其实已经有很多针对重复问题的研究,还有下篇。

Beam Search快速理解及代码解析(下)

先解释一下什么要对Beam Search进行改进。因为Beam Search虽然比贪心强了不少,但还是会生成出空洞、重复、前后矛盾的文本。如果你有文本生成经验,一定对这些现象并不陌生。在语言模型还不像如今的BERT、GPT这么厉害的时候,这种现象更加明显。

没有经验也没关系,我们来看一个论文里面的例子。输入模型的引文(context)

"The study, published in the Proceedings of the They were cattle called Bolivian Cavalleros; they live in a National Academy of Sciences of the United States of remote desert uninterrupted by town, and they speak huge, America (PNAS), was conducted by researchers from the beautiful, paradisiacal Bolivian linguistic thing. They say, Universidad Nacional Autónoma de México (UNAM) and

GPT-2模型, Beam Search, num_beams=32的生成结果:

'Lunch, marge.' They don't tell what the lunch is," director the Universidad Nacional Autónoma de México Professor Chuperas Omwell told Sky News. "They've only (UNAM/Universidad Nacional Autónoma de been talking to scientists, like we're being interviewed by TV México/Universidad Nacional Autónoma de reporters. We don't even stick around to be interviewed by México/Universidad Nacional Autónoma de TV reporters. Maybe that's how they figured out that they're México/Universidad Nacional Autónoma de …”

可以发现即使是如今最顶级的语言模型加上足够长的引文输入,还是无法得到高质量的生成结果。

论文认为这种问题是由于这种试图最大化序列条件概率的解码策略从根上就有问题。他们对比了给定同样引文的情况下人类续写和机器生成的词用语言模型计算出来的概率。如下图所示,人类选择的词(橙线)并不是像机器选择的(蓝线)那样总是那些条件概率最大的词。从生成的结果也可以看出,机器生成的结果有大量重复。

机器选词和人类选词的概率对比图

人们其实尝试了各种办法对Beam Search进行改进,其实都很好理解,这篇论文总结的也比较到位。

随机采样

第一种方法是用随机采样(sampling)代替取概率最大的词。采样的依据就是解码器输出的词典中每个词的概率分布。相比于按概率“掐尖”,这样会增大所选词的范围,引入更多的随机性。当时那篇论文的结论就是这种随机采样的方法远好于Beam Search。但这其实也是有条件的,随机采样容易产生前后不一致的问题。而在开放闲聊领域,生成文本的长度都比较短,这种问题就被自然的淡化了。

采样的时候有一个可以控制的超参数,称为温度(temperature, )。解码器的输出层后面通常会跟一个softmax函数来将输出概率归一化,通过改变 可以控制概率分布的形貌。softmax的公式如下,当 大的时候,概率分布趋向平均,随机性增大;当 小的时候,概率密度趋向于集中,即强者愈强,随机性降低,会更多地采样出“放之四海而皆准”的词汇。

top-k采样

这个方法就是在采样前将输出的概率分布截断,取出概率最大的k个词构成一个集合,然后将这个子集词的概率再归一化,最后从新的概率分布中采样词汇。这个办法据说可以获得比Beam Search好很多的效果,但也有一个问题,就是这个k不太好选。

While top-k sampling leads to considerably higher quality text than either beam search or sampling from the full distribution, the use of a constant k is sub-optimal across varying contexts.

为啥呢?因为这个概率分布变化比较大,有时候可能很均匀(flat),有的时候比较集中(peaked)。对于集中的情况还好说,当分布均匀时,一个较小的k容易丢掉很多优质候选词。但如果k定的太大,这个方法又会退化回普通采样。

两种分布,左边是均匀的,右边是集中的

核采样(Nucleus sampling)

首先表示我不确定这个翻译是不是对的。

这是这篇论文提出的方式,也是相比前面那些都更好的采样方式,这个方法不再取一个固定的k,而是固定候选集合的概率密度和在整个概率分布中的比例。也就是构造一个最小候选集V ,使得

选出来这个集合之后也和top-k采样一样,重新归一化集合内词的概率,并把集合外词的概率设为0。这种方式也称为top-p采样。

论文有一个图,对比了这几种采样方式的效果。

效果对比图,红字是前后不符,蓝字是重复。Nucleus效果拔群。

惩罚重复

为了解决重复问题,还可以通过惩罚因子将出现过词的概率变小或者强制不使用重复词来解决。惩罚因子来自于同样广为流传的《CTRL: A Conditional Transformer Language Model for Controllable Generation》[2]。如果大家感兴趣的话后面可以专门写一期可控文本生成方向的解读。

其实上述各种采样方式在HuggingFace的库里都已经实现了(感动!),我们来看一下代码。

先看top-k和top-p采样

 1 # 代码输入的是logits,而且考虑很周全(我感觉漏了考虑k和p都给了的情况,这应该是不合适的)
 2 # 巧妙地使用了torch.cumsum
 3 # 避免了一个词都选不出来的尴尬情况
 4 def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
 5     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
 6         Args:
 7             logits: logits distribution shape (batch size, vocabulary size)
 8             if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
 9             if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
10                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
11             Make sure we keep at least min_tokens_to_keep per batch example in the output
12         From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
13     """
14     if top_k > 0:
15         top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
16         # Remove all tokens with a probability less than the last token of the top-k
17         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
18         logits[indices_to_remove] = filter_value
19
20     if top_p < 1.0:
21         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
22         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
23
24         # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
25         sorted_indices_to_remove = cumulative_probs > top_p
26         if min_tokens_to_keep > 1:
27             # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
28             sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
29         # Shift the indices to the right to keep also the first token above the threshold
30         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
31         sorted_indices_to_remove[..., 0] = 0
32
33         # scatter sorted tensors to original indexing
34         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
35         logits[indices_to_remove] = filter_value
36     return logits

再看看重复惩罚

 1 # 输入的同样是logits(lprobs)
 2 # 同时输入了之前出现过的词以及惩罚系数(大于1的)
 3 # 考虑到了logit是正和负时处理方式应该不一样
 4 def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
 5         """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
 6         for i in range(batch_size * num_beams):
 7             for previous_token in set(prev_output_tokens[i].tolist()):
 8                 # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
 9                 if lprobs[i, previous_token] < 0:
10                     lprobs[i, previous_token] *= repetition_penalty
11                 else:
12                     lprobs[i, previous_token] /= repetition_penalty

最后是重复词去除

 1 # 这个函数将会返回一个不可使用的词表
 2 # 生成n-gram的巧妙方式大家可以借鉴一下
 3 # 下面是一个3-gram的例子
 4 # a = [1,2,3,4,5]
 5 # for ngram in zip(*[a[i:] for i in range(3)]):
 6 #    print(ngram)
 7 def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
 8     # Copied from fairseq for no_repeat_ngram in beam_search"""
 9     if cur_len + 1 < no_repeat_ngram_size:
10         # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
11         return [[] for _ in range(num_hypos)]
12     generated_ngrams = [{} for _ in range(num_hypos)]
13     for idx in range(num_hypos):
14         gen_tokens = prev_input_ids[idx].numpy().tolist()
15         generated_ngram = generated_ngrams[idx]
16         # 就是这巧妙的一句
17         for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
18             prev_ngram_tuple = tuple(ngram[:-1])
19             generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
20     def _get_generated_ngrams(hypo_idx):
21         # Before decoding the next token, prevent decoding of ngrams that have already appeared
22         start_idx = cur_len + 1 - no_repeat_ngram_size
23         ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
24         return generated_ngrams[hypo_idx].get(ngram_idx, [])
25     banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
26     return banned_tokens

以上这些代码应该在哪里调用相信看上一篇文章的朋友都应该知道了,这里就放出来最核心的差异。

 1 if do_sample:
 2     # 这是今天的采样方式
 3     _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
 4     # Top-p/top-k filtering,这一步重建了候选集
 5     _scores = top_k_top_p_filtering(
 6         _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
 7     )  # (batch_size * num_beams, vocab_size)
 8     # re-organize to group the beam together to sample from all beam_idxs
 9     _scores = _scores.contiguous().view(
10         batch_size, num_beams * vocab_size
11     )  # (batch_size, num_beams * vocab_size)
12
13     # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
14     probs = F.softmax(_scores, dim=-1)
15     # 采样
16     next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
17     # Compute next scores
18     next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
19     # sort the sampled vector to make sure that the first num_beams samples are the best
20     next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
21     next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)
22 else:
23     # 这是昨天的beam search方式
24     # 直接将log概率相加求条件概率
25     next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
26
27     # re-organize to group the beam together (we are keeping top hypothesis accross beams)
28     next_scores = next_scores.view(
29         batch_size, num_beams * vocab_size
30     )  # (batch_size, num_beams * vocab_size)
31
32     next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

OK,谢谢各位看到这里,祝大家生成出高质量的文本!

参考资料

[1] The Curious Case of Neural Text Degeneration: https://arxiv.org/abs/1904.09751

[2] CTRL: A Conditional Transformer Language Model for Controllable Generation: https://arxiv.org/abs/1909.05858

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器

你可能感兴趣的文章