PyTorch CrossEntropyLoss documentation example crashes

Question

To make sure I'm using PyTorch CrossEntropyLoss correctly, I'm trying the examples from the documentation: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

However, the first example (target with class indices) doesn't seem to update the weights, and the second example (target with class probabilities) crashes.

Focusing on the second, being the more obvious kind of error, the complete program I'm running is

import torch
from torch import nn

# Example of target with class probabilities
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)

And the error message is

Traceback (most recent call last):
  File "crossentropy-probabilities.py", line 9, in <module>
    output = loss(input, target)
  File "C:\Users\russe\Anaconda3\envs\torch2\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\russe\Anaconda3\envs\torch2\lib\site-packages\torch\nn\modules\loss.py", line 948, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "C:\Users\russe\Anaconda3\envs\torch2\lib\site-packages\torch\nn\functional.py", line 2422, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "C:\Users\russe\Anaconda3\envs\torch2\lib\site-packages\torch\nn\functional.py", line 2218, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: 1D target tensor expected, multi-target not supported 

Is the documentation in error, or am I missing something obvious?

Answer 1

Score: 1

You are likely using a PyTorch version < 1.10.

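Depending on the version of PyTorch you are using, this feature might not be available. For version 1.10 and upwards, the target tensor can be provided either in dense format (with class indices) or as a probability map (soft labels). In your traceback, cross_entropy forwards straight to nll_loss, which only accepts class indices, hence the "multi-target not supported" error.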
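As a rough illustration of the two target formats, here is a minimal sketch. The version parsing and the names logits, index_target, prob_target and supports_soft_targets are my own choices for this example, not something from the documentation. It guards the soft-label call so the script also runs on an older install instead of crashing.

```python
import torch
from torch import nn

# Parse "major.minor" from a version string such as "1.10.2+cu113".
# (Illustrative parsing; assumes the usual "X.Y.Z[+local]" layout.)
major, minor = (int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
supports_soft_targets = (major, minor) >= (1, 10)

loss = nn.CrossEntropyLoss()
logits = torch.randn(3, 5, requires_grad=True)

# Dense format: one class index per sample, shape (3,), dtype int64.
# This form is accepted by both older and newer versions.
index_target = torch.tensor([1, 0, 4])
index_loss = loss(logits, index_target)
print("class-index loss:", index_loss.item())

# Probability format: one distribution per sample, shape (3, 5), float dtype.
if supports_soft_targets:
    prob_target = torch.randn(3, 5).softmax(dim=1)
    prob_loss = loss(logits, prob_target)
    print("soft-label loss:", prob_loss.item())
else:
    print("soft labels need PyTorch >= 1.10; found", torch.__version__)
```

Printing torch.__version__ in the environment that produced the traceback is probably the quickest way to confirm the diagnosis.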

You can compare the documentation pages of nn.CrossEntropyLoss for versions 1.9.1 and 1.10.
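If upgrading is not an option, the soft-label loss can be written by hand. The sketch below assumes the default reduction ("mean") and no class weights; soft_cross_entropy is a hypothetical helper name, not a PyTorch function.

```python
import torch
import torch.nn.functional as F


def soft_cross_entropy(logits, target_probs):
    """Batch mean of -sum_c p_c * log_softmax(logits)_c.

    Hypothetical helper for versions without soft-label support;
    assumes reduction="mean" and no class weights.
    """
    log_probs = F.log_softmax(logits, dim=1)
    return -(target_probs * log_probs).sum(dim=1).mean()


# Same shapes as the documentation example with class probabilities.
logits = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)

output = soft_cross_entropy(logits, target)
output.backward()  # fills logits.grad; it does not change logits itself
print(output.item(), logits.grad.shape)
```

With one-hot rows as target_probs, this should agree with the class-index form of the loss. Note also that backward() only computes gradients; values change only when something such as an optimizer step applies them, which may be why the first documentation example appears not to update anything.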

huangapple
  • Posted on March 4, 2023, 01:29:26
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75630176.html