使用torch.argsort在CNN中

huangapple go评论54阅读模式
英文:

Using torch.argsort in CNN

问题

我的CNN网络的输出是对数概率。我想要获取输出的前10个最佳索引,并进一步使用这些索引来计算损失。

因此,我必须使用torch.argsort。但是使用torch.argsort会破坏梯度。

请问有人可以帮助我该如何继续吗?
非常感谢您的帮助。

英文:

Output of my CNN network is log probabilities. I want to get 10 best indices of the output and further use the indices to get loss.

Hence, I have to use torch.argsort. But using torch.argsort breaks gradients.
Can anyone please help me on how do I proceed?
Thank you very much for the help

答案1

得分: 1

你可以使用 torch.topk

import torch
out = torch.tensor(torch.randn(10, 100), requires_grad=True)
t_v, t_i = torch.topk(out, 10, dim=-1)
print(t_v.requires_grad)
英文:

You can use torch.topk

import torch
out = torch.tensor(torch.randn(10, 100), requires_grad=True)
t_v, t_i = torch.topk(out, 10, dim=-1)
print(t_v.requires_grad)

答案2

得分: 0

argsort在将其输出应用于掩码张量时不会_破坏梯度_,并且这将保留一些梯度:

top_indices = torch.argsort(-x)[:10] #获取前10个元素的索引
top_values = x[top_indices]
print(top_values.requires_grad)

详见 https://discuss.pytorch.org/t/differentiable-sorting-and-indices/89304

英文:

Just to correct you, argsort does not break gradients when you apply its output to mask tensors and this will keep some gradients:

top_indices = torch.argsort(-x)[:10] #get indices of top-10
top_values = x[top_indices]
print(top_values.requires_grad)

see https://discuss.pytorch.org/t/differentiable-sorting-and-indices/89304

huangapple
  • 本文由 发表于 2023年6月22日 04:53:56
  • 转载请务必保留本文链接:https://go.coder-hub.com/76527069.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定