如何在使用Hugging Face进行推断时添加异常处理?

huangapple go评论48阅读模式
英文:

how to add exception handling during inference with hugging face?

问题

import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# 创建一个 DistilBertTokenizer 实例
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# 创建一个 DistilBertForSequenceClassification 模型实例
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

# 准备输入文本并编码成模型可接受的格式
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# 禁用梯度计算,通常在推断时这样做
with torch.no_grad():
    # 通过模型进行推断
    logits = model(**inputs).logits

# 找到预测类别的标识并获取其对应的标签
predicted_class_id = logits.argmax().item()
predicted_label = model.config.id2label[predicted_class_id]

# 打印预测结果
print(predicted_label)
英文:

I'm trying to download a model from hugging face and run a inference. I'm new to torch library, can someone point me to it's documentation . basedon sample code below,

with torch.no_grad()
  ...

trying to add some exception handling to this code. not sure , what the above line does , but is there any exception in the torch library that i can use to catch , if there is an exception, if the above line fails.

import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]

答案1

得分: 1

with torch.no_grad()禁用了梯度计算。这意味着在计算损失的梯度期间,张量不会被存储。这可以加快推断速度并减少内存使用。

有关更多信息,请查看文档:
https://pytorch.org/docs/stable/generated/torch.no_grad.html

PyTorch本身不提供处理此类异常的实用工具。您可以使用普通的Try/Except,但这通常不是必要的,因为torch.no_grad()不会引发异常。如果您担心代码块中的逻辑,可以简单地将其包装在try中:

with torch.no_grad():
    try:
        logits = model(**inputs).logits
    except Exception as e:
        print(f"发生错误:{e}")
        # 退出、执行一些备选操作或检查错误的类型
# ...
英文:

First off,

> not sure, what the above line does

with torch.no_grad() disables gradient calculation. That means, Tensors aren't stored during the gradient calculations of your loss. This can speed up inference and uses less memory.

See the docs for more info about it: <br>
https://pytorch.org/docs/stable/generated/torch.no_grad.html


> Is there any exception in the torch library that I can use to catch, if there is an exception, if the above line fails?

PyTorch itself doesn't provide utilities for such exception handling. You can use a normal Try/Except but this shouldn't be necessary because torch.no_grad() doesn't raise exceptions. If you're worried about the logic within the block, you can simply wrap it inside a try:

with torch.no_grad():
    try:
        logits = model(**inputs).logits
    except Exception as e:
        print(f&quot;An error occurred: {e}&quot;)
        # exit, some fallback or check the type of the error
# ...

huangapple
  • 本文由 发表于 2023年7月20日 11:11:56
  • 转载请务必保留本文链接:https://go.coder-hub.com/76726452.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定