你能在PyTorch张量之间使用不规则索引进行赋值而无需使用for循环吗?

huangapple go评论76阅读模式
英文:

Can you make assignments between PyTorch tensors using ragged indices without a for loop?

问题

假设我有两个形状相同的PyTorch Tensor对象:

x = torch.randn(2, 10)
y = torch.randn(2, 10)

现在,我有一个索引列表(与第一个Tensor轴的长度相同),该列表提供第二个Tensor轴上不同起始位置的索引,我想要从y赋值到x,即

idxs = [2, 6]
for i, idx in enumerate(idxs):
    x[i, idx:] = y[i, idx:]

如上所示,我可以使用for循环做到这一点,但我的问题是是否有一种更高效的方法来实现这一点,而无需显式使用for循环?

英文:

Suppose I have two PyTorch Tensor objects of equal shape:

import torch

x = torch.randn(2, 10)
y = torch.randn(2, 10)

Now, I have a list of indices (of the same length as the first Tensor axis) which give different starting positions in the second Tensor axis from which I want to assign values from y into x, i.e.,

idxs = [2, 6]
for i, idx in enumerate(idxs):
    x[i, idx:] = y[i, idx:]

As above, I can do this with a for loop, but my question is whether there is a more efficient way of doing this without an explicit for loop?

答案1

得分: 2

首先,在你的张量的第二维上创建一个索引张量:

second_dim_indices = torch.arange(x.shape[1])

然后,将idxs转换为张量:

idxs = torch.LongTensor(idxs)

接着,可以计算一个掩码,当张量索引需要修改时为真:

mask = second_dim_indices.unsqueeze(0) >= idxs.unsqueeze(1)
# 在你的情况下 =
# tensor([[False, False,  True,  True,  True,  True,  True,  True,  True,   True],
#         [False, False, False, False, False, False,  True,  True,  True,  True]])

注意,我们必须对索引和idxs进行unsqueeze以进行>=运算。

最后,使用掩码来更新x

x = y * mask + x * ~mask
英文:

First, create a index tensor on the second dimension of your tensor with

second_dim_indices = torch.arange(x.shape[1])

and turn idxs into a tensor:

idxs = torch.LongTensor(idxs)

Then, it is possible to compute a mask that's true when tensor index must be modified with :

mask = second_dim_indices.unsqueeze(0) >= idxs.unsqueeze(1)
# in your case =
#  tensor([[False, False,  True,  True,  True,  True,  True,  True,  True,   True],
#          [False, False, False, False, False, False,  True,  True,  True,  True]])

Note that we must unsqueeze indices and idxs to broadcast the >= operation.

Finally, use the mask to update x:

x = y * mask + x * ~mask

huangapple
  • 本文由 发表于 2023年6月21日 23:55:56
  • 转载请务必保留本文链接:https://go.coder-hub.com/76525137.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定