
How does `enforce_stop_tokens` work in LangChain with Huggingface models?

Question


When we look at HuggingFaceHub model usage in LangChain, there's this part where even the author doesn't know a better way to stop the generation, https://github.com/hwchase17/langchain/blob/master/langchain/llms/huggingface_pipeline.py#L182:

    class HuggingFacePipeline(LLM):
        ...
        def _call(
        ...
            if stop is not None:
                # This is a bit hacky, but I can't figure out a better way to enforce
                # stop tokens when making calls to huggingface_hub.
                text = enforce_stop_tokens(text, stop)
            return text

What should I use to add the stop token to the end of the template?


If we look at https://github.com/hwchase17/langchain/blob/master/langchain/llms/utils.py, it's simply a regex split that splits an input string based on a list of stop words, then takes the first partition of the `re.split`:

    re.split("|".join(stop), text)[0]
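Note that `"|".join(stop)` treats each stop word as a regex pattern, so a stop word containing a metacharacter (e.g. `?` or `.`) would raise an error or mis-split. A safer variant is sketched below; the `re.escape` call is my addition, not in the LangChain source:

```python
import re

def enforce_stop_tokens_escaped(text, stop):
    """Cut off text at the first stop word, escaping regex metacharacters.
    (re.escape is an addition -- LangChain's version joins the raw strings,
    so re.split("?", ...) would raise re.error.)"""
    return re.split("|".join(re.escape(s) for s in stop), text)[0]

print(enforce_stop_tokens_escaped("Answer: 42. Next question?", ["?"]))
# -> Answer: 42. Next question
```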

Let's try to get a generation output from a Huggingface model, e.g.

    from transformers import pipeline
    from transformers import GPT2LMHeadModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
    output = generator("Hey Pizza! ")
    output

[out]:

    [{'generated_text': 'Hey Pizza! \n\nHurry up, leave the place! \n\nOi! \n\nWhile eating pizza and then, Yuigahama came in contact with Ruriko in the middle of the'}]

If we apply the re.split:

    import re

    def enforce_stop_tokens(text, stop):
        """Cut off the text as soon as any stop words occur."""
        return re.split("|".join(stop), text)[0]

    stop = ["up", "then"]
    text = output[0]['generated_text']
    re.split("|".join(stop), text)

[out]:

    ['Hey Pizza! \n\nHurry ',
     ', leave the place! \n\nOi! \n\nWhile eating pizza and ',
     ', Yuigahama came in contact with Ruriko in the middle of the']

But that isn't useful; I want to split at the point where the generation ends. What tokens do I use with `enforce_stop_tokens`?

Answer 1

Score: 1


You could do this by setting the eos_token_id as your stop term(s); in my testing it seemed to work with a list. See below: the regex cuts off the stop word itself, while eos_token_id cuts off just after the stop word ("once upon a time" vs. "once upon a").

    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    import regex as re

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    # Define your custom stop terms
    stop_terms = ["right", "time"]

    # Ensure the stop terms are in the tokenizer's vocabulary
    for term in stop_terms:
        if term not in tokenizer.get_vocab():
            tokenizer.add_tokens([term])
            model.resize_token_embeddings(len(tokenizer))

    def enforce_stop_tokens(text, stop):
        """Cut off the text as soon as any stop words occur."""
        return re.split("|".join(stop), text)[0]

    # Get the token IDs for your custom stop terms
    eos_token_ids_custom = [tokenizer.encode(term, add_prefix_space=True)[0] for term in stop_terms]

    # Generate text
    input_text = "Once upon "
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    output_ids = model.generate(input_ids, eos_token_id=eos_token_ids_custom, max_length=50)

    # Decode the output IDs to text
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(generated_text)  # Once upon a time

    print("ENFORCE STOP TOKENS")
    truncated_text = enforce_stop_tokens(generated_text, stop_terms)
    print(truncated_text)  # Once upon a
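As a follow-up, recent transformers versions also expose a `StoppingCriteria` hook, which avoids adding tokens and resizing embeddings. A minimal sketch; the class name and wiring here are my own, not from the answer above:

```python
# Sketch of a custom stopping condition using transformers' StoppingCriteria
# API (available in recent transformers versions).
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last generated token is a stop token."""
    def __init__(self, stop_ids):
        super().__init__()
        self.stop_ids = set(stop_ids)

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids has shape (batch, seq_len); int() handles both a
        # torch scalar tensor and a plain Python int
        return int(input_ids[0][-1]) in self.stop_ids

# Usage with the model/tokenizer from the answer above, e.g.:
# criteria = StoppingCriteriaList([StopOnTokens(eos_token_ids_custom)])
# output_ids = model.generate(input_ids, stopping_criteria=criteria, max_length=50)
```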

huangapple
  • Posted on 2023-06-15 00:04:46
  • Please keep this link when reposting: https://go.coder-hub.com/76475527.html