ImportError: cannot import name 'CustomLLM' from 'llama_index.llms'

Question


I'm having difficulty working with llama_index. I want to load a custom LLM and use it. Fortunately, their documentation has an example that matches my exact need; unfortunately, it does not work!
Their example uses these imports:

    from llama_index.llms import CustomLLM, CompletionResponse, LLMMetadata

And when I run it, I get this error:

    ImportError: cannot import name 'CustomLLM' from 'llama_index.llms'

My llama_index version is 0.7.1 (the latest version at the time of writing). Do you know of any workaround that would let me use a custom dataset in llama_index?

P.S. If their full code is needed, here it is:

    import torch
    from transformers import pipeline
    from typing import Optional, List, Mapping, Any
    from llama_index import (
        ServiceContext,
        SimpleDirectoryReader,
        LangchainEmbedding,
        ListIndex
    )
    from llama_index.llms import CustomLLM, CompletionResponse, LLMMetadata

    # set context window size
    context_window = 2048
    # set number of output tokens
    num_output = 256

    # store the pipeline/model outside of the LLM class to avoid memory issues
    model_name = "facebook/opt-iml-max-30b"
    pipeline = pipeline("text-generation", model=model_name, device="cuda:0", model_kwargs={"torch_dtype": torch.bfloat16})

    class OurLLM(CustomLLM):
        @property
        def metadata(self) -> LLMMetadata:
            """Get LLM metadata."""
            return LLMMetadata(
                context_window=context_window, num_output=num_output
            )

        def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
            prompt_length = len(prompt)
            response = pipeline(prompt, max_new_tokens=num_output)[0]["generated_text"]

            # only return newly generated tokens
            text = response[prompt_length:]
            return CompletionResponse(text=text)

        def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
            raise NotImplementedError()

    # define our LLM
    llm = OurLLM()

    service_context = ServiceContext.from_defaults(
        llm=llm,
        context_window=context_window,
        num_output=num_output
    )

    # load your data
    documents = SimpleDirectoryReader('./data').load_data()
    index = ListIndex.from_documents(documents, service_context=service_context)

    # query and print the response
    query_engine = index.as_query_engine()
    response = query_engine.query("<query_text>")
    print(response)

Answer 1

Score: 1

You need to change your imports. Change

    from llama_index.llms import CustomLLM, CompletionResponse, LLMMetadata

to this:

    from llama_index.llms.custom import CustomLLM
    from llama_index.llms.base import CompletionResponse, LLMMetadata
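
With those imports swapped in, the ImportError goes away. Note one more latent issue in the quoted docs snippet: stream_complete is annotated with CompletionResponseGen, which is never imported, so the class definition would fail with a NameError next. Below is a minimal sketch of a corrected import block, assuming CompletionResponseGen sits alongside CompletionResponse in llama_index.llms.base for 0.7.x (verify against your installed version):

    import torch
    from transformers import pipeline
    from typing import Any
    from llama_index import ServiceContext, SimpleDirectoryReader, ListIndex
    from llama_index.llms.custom import CustomLLM
    # Assumption: in 0.7.x, CompletionResponseGen lives in llama_index.llms.base
    # next to CompletionResponse; adjust the module path if your version differs.
    from llama_index.llms.base import CompletionResponse, CompletionResponseGen, LLMMetadata

If an import like this still fails, you can locate the symbol in your installed package instead of guessing at module paths:

    # Walk llama_index's submodules and report every module that exposes CustomLLM.
    # Uses only the standard library, so it makes no assumptions about the layout.
    import importlib
    import pkgutil

    import llama_index

    for mod_info in pkgutil.walk_packages(llama_index.__path__, prefix="llama_index."):
        try:
            mod = importlib.import_module(mod_info.name)
        except Exception:
            continue  # skip submodules whose optional dependencies are missing
        if getattr(mod, "CustomLLM", None) is not None:
            print(f"CustomLLM importable from {mod_info.name}")

Later llama_index releases re-export CustomLLM from llama_index.llms directly, so upgrading past 0.7.1 is another way out.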
