How to copy elements of one pytorch tensor at given indices into another tensor without intermediate allocation or looping

huangapple go评论104阅读模式
英文:

How to copy elements of one pytorch tensor at given indices into another tensor without intermediate allocation or looping

问题

我想要执行以下操作,但不生成由 a[idx] 产生的中间缓冲区或循环遍历 idx。我该如何做到这一点?

b.index_add_(0, idx, a)
英文:

Given

import torch

a: torch.Tensor
b: torch.Tensor
assert a.shape[1:] == b.shape[1:]
idx = torch.randint(b.shape[0], [a.shape[0]])

I want to do

b[...] = a[idx]

But without intermediate buffer produced by a[idx] or looping over idx. How do I do this?

答案1

得分: 0

你可以使用 torch.index_select

torch.index_select(a, 0, idx, out=b)
英文:

You can use torch.index_select:

torch.index_select(a, 0, idx, out = b)

huangapple
  • 本文由 发表于 2023年8月5日 05:15:13
  • 转载请务必保留本文链接:https://go.coder-hub.com/76839139.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定