Using torch.argsort in CNN
Question
The output of my CNN is a tensor of log probabilities. I want to take the indices of the 10 best outputs and then use those indices to compute the loss.
So I have to use torch.argsort, but using torch.argsort breaks the gradients.
Can anyone help me with how to proceed?
Thank you very much for the help.
Answer 1
Score: 1
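You can use torch.topk, which returns both the top values and their indices:
import torch
# Create the leaf tensor directly; wrapping an existing tensor in torch.tensor() copies it and triggers a warning.
out = torch.randn(10, 100, requires_grad=True)
# t_v: top-10 values per row (differentiable); t_i: their integer indices
t_v, t_i = torch.topk(out, 10, dim=-1)
print(t_v.requires_grad)  # True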
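To connect this to the loss in the question, here is a minimal sketch. It assumes log probabilities of shape (batch, classes) and a purely illustrative loss, the negative mean of the top-10 log probabilities; the function name top10_loss and this choice of loss are hypothetical and do not come from the question. It checks that gradients reach only the selected entries.

```python
import torch


def top10_loss(log_probs: torch.Tensor, k: int = 10) -> torch.Tensor:
    """Hypothetical loss: negative mean of the k largest log probabilities per row."""
    top_values, top_indices = torch.topk(log_probs, k, dim=-1)
    # top_indices is an integer tensor and carries no gradient;
    # top_values stays connected to log_probs in the autograd graph.
    return -top_values.mean()


torch.manual_seed(0)
logits = torch.randn(4, 100, requires_grad=True)  # stand-in for the CNN output
log_probs = torch.log_softmax(logits, dim=-1)
log_probs.retain_grad()  # keep the gradient of this non-leaf tensor for inspection

loss = top10_loss(log_probs)
loss.backward()

# Count the entries per row that received a non-zero gradient.
nonzero_per_row = (log_probs.grad != 0).sum(dim=-1)
print(nonzero_per_row)
```

Only the k selected positions in each row receive a gradient with respect to log_probs; the other positions are affected only indirectly, through the softmax normalization upstream.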
Answer 2
Score: 0
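A small correction: argsort does not break gradients when you use its output to index a tensor. The indices themselves are integers and carry no gradient, but the selected values stay in the autograd graph:
# x is assumed to be a 1-D tensor with requires_grad=True
top_indices = torch.argsort(x, descending=True)[:10]  # indices of the top-10 values
top_values = x[top_indices]
print(top_values.requires_grad)  # True
See https://discuss.pytorch.org/t/differentiable-sorting-and-indices/89304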
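For a batched output, plain indexing with x[top_indices] no longer selects per row. One possible approach, sketched here on the assumption that x has shape (batch, classes), is to combine torch.argsort with torch.gather. The shapes and the sum used as a stand-in loss are illustrative only.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 100, requires_grad=True)  # assumed (batch, classes) scores

# Indices of the 10 largest entries in each row; integer tensor, no gradient.
top_indices = torch.argsort(x, dim=-1, descending=True)[:, :10]

# gather picks x[i, top_indices[i, j]] and is differentiable with respect to x.
top_values = torch.gather(x, dim=-1, index=top_indices)

# Stand-in loss: the sum of the selected values.
top_values.sum().backward()

print(top_indices.requires_grad)
print(top_values.requires_grad)
print(x.grad.sum(dim=-1))  # per-row sum of gradients
```

With this stand-in loss, each selected position gets a gradient of 1 and every other position gets 0, so the selection passes gradients through to the chosen elements only. The ordering of the indices is not itself differentiable.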