How does `enforce_stop_tokens` work in LangChain with Huggingface models?
Question
When we look at HuggingFaceHub model usage in LangChain, there's this part where the author admits to not knowing a better way to stop the generation, https://github.com/hwchase17/langchain/blob/master/langchain/llms/huggingface_pipeline.py#L182:
class HuggingFacePipeline(LLM):
    ...
    def _call(
        ...
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
What should I use to add the stop token to the end of the template?
If we look at https://github.com/hwchase17/langchain/blob/master/langchain/llms/utils.py, it's simply a regex split that splits the input string on a list of stop words, then takes the first partition of the re.split:

re.split("|".join(stop), text)[0]
Let's try to get a generation output from a HuggingFace model, e.g.

from transformers import pipeline
from transformers import GPT2LMHeadModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

output = generator("Hey Pizza! ")
output
[out]:
[{'generated_text': 'Hey Pizza! 」\n\n「Hurry up, leave the place! 」\n\n「Oi! 」\n\nWhile eating pizza and then, Yuigahama came in contact with Ruriko in the middle of the'}]
If we apply the re.split:

import re

def enforce_stop_tokens(text, stop):
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text)[0]

stop = ["up", "then"]
text = output[0]['generated_text']
re.split("|".join(stop), text)
[out]:
['Hey Pizza! 」\n\n「Hurry ',
', leave the place! 」\n\n「Oi! 」\n\nWhile eating pizza and ',
', Yuigahama came in contact with Ruriko in the middle of the']
But that isn't useful; I want to split at the point where the generation ends. What tokens do I use with `enforce_stop_tokens`?
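One caveat about this helper worth keeping in mind: `"|".join(stop)` inserts the stop strings into the regex pattern unescaped, so a stop string containing regex metacharacters (e.g. `.`, `?`, `(`) is interpreted as a pattern rather than a literal. A minimal sketch of a safer variant (`enforce_stop_tokens_escaped` is a hypothetical helper, not part of LangChain) escapes each stop string with `re.escape` first:

```python
import re

def enforce_stop_tokens_escaped(text, stop):
    """Like enforce_stop_tokens, but treats each stop string as a
    literal by escaping regex metacharacters before joining."""
    return re.split("|".join(re.escape(s) for s in stop), text)[0]

# Unescaped, the stop string "..." matches ANY three characters:
print(re.split("|".join(["..."]), "Hello... world")[0])        # ''
# Escaped, it only matches a literal "...":
print(enforce_stop_tokens_escaped("Hello... world", ["..."]))  # Hello
```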
Answer 1
Score: 1
You could do this by setting the eos_token_id to your stop term(s); in my testing it seemed to work with a list. See below: the regex cuts off before the stop word, while eos_token_id cuts off just after the stop word ("once upon a time" vs. "once upon a"):
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import regex as re

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Define your custom stop terms
stop_terms = ["right", "time"]

# Ensure the stop terms are in the tokenizer's vocabulary
for term in stop_terms:
    if term not in tokenizer.get_vocab():
        tokenizer.add_tokens([term])
        model.resize_token_embeddings(len(tokenizer))

def enforce_stop_tokens(text, stop):
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text)[0]

# Get the token IDs for your custom stop terms
eos_token_ids_custom = [tokenizer.encode(term, add_prefix_space=True)[0] for term in stop_terms]

# Generate text
input_text = "Once upon "
input_ids = tokenizer.encode(input_text, return_tensors='pt')
output_ids = model.generate(input_ids, eos_token_id=eos_token_ids_custom, max_length=50)

# Decode the output IDs to text
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(generated_text)  # Once upon a time

print("ENFORCE STOP TOKENS")
truncated_text = enforce_stop_tokens(generated_text, stop_terms)
print(truncated_text)  # Once upon a
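As an alternative to overloading eos_token_id, transformers also exposes a StoppingCriteria hook that model.generate checks after each decoding step. A sketch of that approach (the `StopOnStrings` class below is a hypothetical helper, not part of transformers or LangChain; it assumes the standard `StoppingCriteria.__call__(input_ids, scores)` signature):

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnStrings(StoppingCriteria):
    """Hypothetical helper: stop generation as soon as any stop string
    appears in the decoded output so far."""
    def __init__(self, stop_strings, tokenizer):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the tokens generated so far and check for any stop string.
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(s in text for s in self.stop_strings)

# With the gpt2 model/tokenizer from above, this would be used as:
#   criteria = StoppingCriteriaList([StopOnStrings(stop_terms, tokenizer)])
#   output_ids = model.generate(input_ids, stopping_criteria=criteria, max_length=50)
```

Unlike the eos_token_id trick, this checks the decoded text rather than single token IDs, so it also handles stop terms that tokenize into more than one token. The output still contains the stop string itself, so `enforce_stop_tokens` remains useful for trimming it off afterwards.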