ImportError: cannot import name 'CustomLLM' from 'llama_index.llms'

Question

I'm having difficulty working with llama_index. I want to load and use a custom LLM. Fortunately, their documentation has an example that does exactly what I need; unfortunately, it does not work!
They have these imports in their example:

from llama_index.llms import CustomLLM, CompletionResponse, LLMMetadata

When I run it, I get this error:

ImportError: cannot import name 'CustomLLM' from 'llama_index.llms'
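For context on what this error means: `from package import Name` only succeeds if the package's `__init__.py` makes `Name` available as an attribute (or `Name` is itself a submodule). A class defined in a submodule that `__init__.py` never imports is invisible at the package level. The self-contained sketch below reproduces this with a throwaway package; `fakepkg` and its `custom` submodule are hypothetical stand-ins, not llama_index's actual layout.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package on disk. "fakepkg" and its "custom" submodule are
# hypothetical names that only mimic a package whose __init__.py does not
# re-export a class defined in one of its submodules.
with tempfile.TemporaryDirectory() as tmp:
    pkg_dir = Path(tmp) / "fakepkg"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text("")  # re-exports nothing
    (pkg_dir / "custom.py").write_text("class CustomLLM:\n    pass\n")

    sys.path.insert(0, tmp)
    importlib.invalidate_caches()  # make sure the new files are seen

    # 1) Importing the class from the package fails: fakepkg has no attribute
    #    CustomLLM and no submodule with that name.
    try:
        from fakepkg import CustomLLM
        package_import_ok = True
    except ImportError as exc:
        package_import_ok = False
        print(exc)  # the message names CustomLLM and fakepkg

    # 2) Importing it from the submodule that defines it works.
    from fakepkg.custom import CustomLLM

    sys.path.remove(tmp)

print(package_import_ok, CustomLLM.__name__)
```

This is the same mechanism the submodule-level imports in the answer rely on.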

My llama_index version is 0.7.1 (the latest version at the time of writing). Do you know of any workaround that would let me use a custom dataset in llama_index?
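Because the correct import path depends on the installed release, it can help to confirm which version the running interpreter actually sees, since it may differ from what `pip` reports for another environment. A minimal stdlib-only sketch (Python 3.8+); `installed_version` is a hypothetical helper name, and `llama-index` is assumed to be the PyPI distribution name for the `llama_index` package.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Assumption: "llama-index" is the distribution name; the import name is llama_index.
print(installed_version("llama-index"))
```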

P.S. In case their full code is needed, here it is:

import torch
from transformers import pipeline
from typing import Optional, List, Mapping, Any

from llama_index import (
    ServiceContext, 
    SimpleDirectoryReader, 
    LangchainEmbedding, 
    ListIndex
)
from llama_index.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata


# set context window size
context_window = 2048
# set number of output tokens
num_output = 256

# store the pipeline/model outside of the LLM class to avoid memory issues
model_name = "facebook/opt-iml-max-30b"
pipeline = pipeline("text-generation", model=model_name, device="cuda:0", model_kwargs={"torch_dtype": torch.bfloat16})

class OurLLM(CustomLLM):

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window, num_output=num_output
        )

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt_length = len(prompt)
        response = pipeline(prompt, max_new_tokens=num_output)[0]["generated_text"]

        # only return newly generated tokens
        text = response[prompt_length:]
        return CompletionResponse(text=text)
    
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()

# define our LLM
llm = OurLLM()

service_context = ServiceContext.from_defaults(
    llm=llm, 
    context_window=context_window, 
    num_output=num_output
)

# Load your data
documents = SimpleDirectoryReader('./data').load_data()
index = ListIndex.from_documents(documents, service_context=service_context)

# Query and print response
query_engine = index.as_query_engine()
response = query_engine.query("<query_text>")
print(response)
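The `complete` method in this example relies on the Hugging Face text-generation pipeline returning the prompt followed by the continuation in `generated_text` (its default behaviour, `return_full_text=True`), so slicing off `len(prompt)` characters leaves only the new text. That logic can be checked without loading a 30B model by swapping in a stub. A minimal sketch; `fake_pipeline` and `complete_text` are hypothetical helpers that mirror the example's steps, not part of transformers or llama_index.

```python
from typing import Any, Callable, Dict, List

PipelineFn = Callable[..., List[Dict[str, Any]]]


def fake_pipeline(prompt: str, max_new_tokens: int = 256) -> List[Dict[str, Any]]:
    """Hypothetical stub shaped like text-generation output: a list with one
    dict whose "generated_text" starts with the prompt. max_new_tokens is ignored."""
    return [{"generated_text": prompt + " Paris."}]


def complete_text(prompt: str, pipe: PipelineFn = fake_pipeline, num_output: int = 256) -> str:
    """Same steps as OurLLM.complete: generate, then drop the echoed prompt."""
    response = pipe(prompt, max_new_tokens=num_output)[0]["generated_text"]
    # Only return newly generated text (assumes the prompt is echoed verbatim).
    return response[len(prompt):]


print(repr(complete_text("The capital of France is")))
```

With the real pipeline, passing `return_full_text=False` is an alternative that makes the slicing unnecessary.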

Answer 1

Score: 1

You need to change your imports. Replace

from llama_index.llms import CustomLLM, CompletionResponse, LLMMetadata

with:

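from llama_index.llms.custom import CustomLLM
# CompletionResponseGen is also needed for the stream_complete() return annotation
from llama_index.llms.base import CompletionResponse, CompletionResponseGen, LLMMetadata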
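If one script has to run against installs where these names live at different paths (for example, the path used in the documentation versus the submodule paths in this answer), one option is to try each candidate module in turn. A minimal sketch; `import_first` is a hypothetical helper, not part of llama_index, and the module paths in the commented usage are only the ones mentioned in this thread.

```python
import importlib
from typing import Any, List


def import_first(name: str, *module_paths: str) -> Any:
    """Return attribute `name` from the first importable module that defines it.

    Raises ImportError listing every candidate tried if none of them works.
    """
    failures: List[str] = []
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError as exc:  # also covers ModuleNotFoundError
            failures.append(f"{path}: {exc}")
            continue
        if hasattr(module, name):
            return getattr(module, name)
        failures.append(f"{path}: no attribute {name!r}")
    raise ImportError(f"cannot import {name!r}; tried " + "; ".join(failures))


# Hypothetical usage with the paths discussed in this thread (needs llama_index):
# CustomLLM = import_first("CustomLLM", "llama_index.llms", "llama_index.llms.custom")
# LLMMetadata = import_first("LLMMetadata", "llama_index.llms", "llama_index.llms.base")

# Stdlib demonstration: the first module does not exist, so os.path is used.
join = import_first("join", "no_such_module_xyz", "os.path")
print(join("a", "b"))
```

Importing directly from the submodules, as shown in the answer, is simpler when only one version needs to be supported.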

huangapple
  • Published on July 6, 2023 at 13:32:46
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76625768.html