使用Transformers做限制集束搜索(Constrained Beam Search)的文本生成
平均阅读时长为 4分钟
MLNLP ( 机器学习算法与自然语言处理 )社区是国内外知名自然语言处理社区,受众覆盖国内外NLP硕博生、高校老师以及企业研究人员。
社区的愿景 是促进国内外自然语言处理,机器学习学术界、产业界和广大爱好者之间的交流,特别是初学者同学们的进步。
(
机器学习算法与自然语言处理 )社区是国内外知名自然语言处理社区,受众覆盖国内外NLP硕博生、高校老师以及企业研究人员。
来源 | https://huggingface.co/blog/constrained-beam-search
翻译 | 金属成色
1
『简介』
开始前,我们需要先熟悉beam search技术,详见How to generate text: using different decoding methods for language generation with Transformers,或中文翻译版
不像一般的beam search,constrained beam search可以对生成文本施加控制,因为很多时候,我们是明确知道生成文本之中是应该包含哪些内容的。例如在神经网络机器翻译任务中,通过查词典,我们明确知道生成文本应该包含的专业词汇。有时,生成的结果是满足语言模型的要求,但只是因为未包含部分关键信息,就不能够满足用户的需求。这些场景都需要让用户告诉模型,哪些词汇必须被包含在生成的内容之中。
2
『难点在哪』
然而,做到这个并不简单。因为这个任务需要我们在最终生成结果中的某个位置强制生成子内容。比如,我们需要生成一个句子,必须包含内容,和是按顺序排列的。我们定义想要的句子如下:
难点在于beam search是在token粒度生成句子,我们可以把beam search简化为一个函数(虽然不完全准确),也就是说,输入从位置至的token,预测位置的token。但是这个函数是如何知道,在位置$i<k$时,强制生成的token是要等到将来生成;在位置$i=k$时,强制生成的token是要在当前生成而不是等到$i>k$时呢。</k$时,强制生成的token是要等到将来生成;在位置$i=k$时,强制生成的token是要在当前生成而不是等到$i>
3
『例子一:强制词(Forcing a Word)』
How old are you?
到德文时,非正式场景可以说Wie alt bist du?
,正式场景可以说Wie alt sind Sie?
。这需要依赖上下文,与上下文保持一致,我们如何告诉模型这样做呢。传统集束搜索(Traditional Beam Search)
pip install -q git+https://github.com/huggingface/transformers.git
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
outputs = model.generate(
input_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
Wie alt bist du?
限制集束搜索(Constrained Beam Search)
force_words_ids
来实现控制模型生成结果tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
force_words = ["Sie"]
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
outputs = model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=5,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
Wie alt sind Sie?
4
『例子二:分离约束(Disjunctive Constraints)』
["raining", "rained", "rains",...]
中的任何一个都可以呢。更进一步,我们经常不想要精确到一个字母都不差的词来作为强制输出子内容。from transformers import GPT2LMHeadModel, GPT2Tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]
force_words_ids = [
tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
]
starting_text = ["The soldiers", "The child"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids
outputs = model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(tokenizer.decode(outputs[1], skip_special_tokens=True))
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Output:
----------------------------------------------------------------------------------------------------
The soldiers, who were all scared and screaming at each other as they tried to get out of the
The child was taken to a local hospital where she screamed and scared for her life, police said.
screaming
,生成的第二句使用了screamed
,同时也都使用了scared
。5
『传统集束搜索(Traditional Beam Search)』
num_beams=3
的集束搜索的第一步的另一种展示方式如下:The dog
,而集束搜索会允许进一步考虑The nice
和The car
。num_beams
。我们不能配置num_beams
太大,因为对于n步的生成,我们要计算个分支。随着num_beams
变大,分支数会快速变大,例如num_beams=10
计算10步,就意味着10,000,000,000
个分支。<eos>
,或者生成token数到达上线。在计算的每一步都会经历,列出所有生成分支、排序、减少分支至num_beams、重复计算。6
『限制集束搜索(Constrained Beam Search)』
"is fast"
。top k
个概率最高的下一个token,然后把它们都加入至考虑范围。在限制集束搜索中,我们仍然会这样做,不过我们也会加入我们的强制生成token。dog
和nice
,同时我们也把强制生成tokenis
也放入考虑分支中,从而尽可能生成我们想要的短语is fast
。Banks
is fast
时,大多时候,我们得到的是不符合逻辑的输出,例如The is fast
。这实际上是一个较复杂问题。在huggingface/transformers的request issue中有深入讨论这个问题的复杂性。num_beams=3
,我们只保留三个分支,所以留下了["The is fast", "The dog is", "The dog and"]
,分别对应概率最高的Bank 2、Bank 1、 Bank 0。"The is fast"
完全满足我们的强制限制,但它是不符合常识的短语。幸运的是,我们还有"The dog is"
、"The dog and"
分支可以在后面的步骤中继续计算,它们很有希望会输出更符合常识的结果,进而在BANK 2的排序中替换掉"The is fast"
。"The is fast"
分支的下一个token预测,不再需要加入强制限制token了,因为强制限制token已经完全满足了。同时注意分支如"The dog is slow"
或"The dog is mad"
,它们虽然包含了限制词"is"
,但是在"is"
后面加入了"slow"
。因此只能重新开始生成"is fast"
,所以它们从Bank 1回到了Bank 0。"The dog is fast"
,即满足了强制限制的短语,又满足较高的输出概率,即符合常识。"The is fast"
已经在轮序调度选择(round-robin selection)中被排除掉了,因为它只在Bank 2中排到最后一名,如上图所示。7
『关于Constraint Classes、Custom Constraints的更多信息』
model.generate()
函数中我们有了force_words_ids
来控制强制生成,但我们可以做一个更好的实施设计。我们把每个限制设计成一个限制对象,它们在集束搜索过程中,分别记录下一个限制生成的token,如下所示:from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PhrasalConstraint
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
constraints = [
PhrasalConstraint(
tokenizer("Sie", add_special_tokens=False).input_ids
)
]
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
outputs = model.generate(
input_ids,
constraints=constraints,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
Wie alt sind Sie?
OrderedConstraints
, TemplateConstraints
,也许将来可以加进来。当前的限制类只是为了满足生成结果包含子内容,它在生成结果的位置没有关系。例如,一个刚才的例子是scared
后面接screaming
,另一个是screamed
后面接scared
。OrderedConstraints
可以允许用户指定这些顺序限制。TemplateConstraints
可以允许用户输入更多特征,例如:starting_text = "The woman"
template = ["the", "", "School of", "", "in"]
possible_outputs == [
"The woman attended the Ross School of Business in Michigan.",
"The woman was the administrator for the Harvard School of Business in MA."
]
或者:
starting_text = "The woman"
template = ["the", "", "", "University", "", "in"]
possible_outputs == [
"The woman attended the Carnegie Mellon University in Pittsburgh.",
]
impossible_outputs == [
"The woman attended the Harvard University in MA."
]
或者用户不关心两个token之间有几个token,只是使用OrderedConstraint。
8
『总结』
限制生成结果必须包含短语 一些短语是有可选列表,一些是不可选的 短语生成在指定的位置的
Guided Open Vocabulary Image Captioning with Constrained Beam Search Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting Guided Generation of Cause and Effect
技术交流群邀请函
△长按添加小助手
扫描二维码添加小助手微信
关于我们
最新评论
推荐文章
作者最新文章
你可能感兴趣的文章
Copyright Disclaimer: The copyright of contents (including texts, images, videos and audios) posted above belong to the User who shared or the third-party website which the User shared from. If you found your copyright have been infringed, please send a DMCA takedown notice to [email protected]. For more detail of the source, please click on the button "Read Original Post" below. For other communications, please send to [email protected].
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。