英文:
How to copy elements of one pytorch tensor at given indices into another tensor without intermediate allocation or looping
问题
我想要执行以下操作,但不生成由 a[idx]
产生的中间缓冲区或循环遍历 idx
。我该如何做到这一点?
b.index_add_(0, idx, a)
英文:
Given
import torch
a: torch.Tensor
b: torch.Tensor
assert a.shape[1:] == b.shape[1:]
idx = torch.randint(b.shape[0], [a.shape[0]])
I want to do
b[...] = a[idx]
But without intermediate buffer produced by a[idx]
or looping over idx
. How do I do this?
答案1
得分: 0
你可以使用 torch.index_select
:
torch.index_select(a, 0, idx, out=b)
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论