你能在PyTorch张量之间使用不规则索引进行赋值而无需使用for循环吗?

huangapple go评论102阅读模式
英文:

Can you make assignments between PyTorch tensors using ragged indices without a for loop?

问题

假设我有两个形状相同的PyTorch Tensor对象:

  1. x = torch.randn(2, 10)
  2. y = torch.randn(2, 10)

现在,我有一个索引列表(与第一个Tensor轴的长度相同),该列表提供第二个Tensor轴上不同起始位置的索引,我想要从y赋值到x,即

  1. idxs = [2, 6]
  2. for i, idx in enumerate(idxs):
  3. x[i, idx:] = y[i, idx:]

如上所示,我可以使用for循环做到这一点,但我的问题是是否有一种更高效的方法来实现这一点,而无需显式使用for循环?

英文:

Suppose I have two PyTorch Tensor objects of equal shape:

  1. import torch
  2. x = torch.randn(2, 10)
  3. y = torch.randn(2, 10)

Now, I have a list of indices (of the same length as the first Tensor axis) which give different starting positions in the second Tensor axis from which I want to assign values from y into x, i.e.,

  1. idxs = [2, 6]
  2. for i, idx in enumerate(idxs):
  3. x[i, idx:] = y[i, idx:]

As above, I can do this with a for loop, but my question is whether there is a more efficient way of doing this without an explicit for loop?

答案1

得分: 2

首先,在你的张量的第二维上创建一个索引张量:

  1. second_dim_indices = torch.arange(x.shape[1])

然后,将idxs转换为张量:

  1. idxs = torch.LongTensor(idxs)

接着,可以计算一个掩码,当张量索引需要修改时为真:

  1. mask = second_dim_indices.unsqueeze(0) >= idxs.unsqueeze(1)
  2. # 在你的情况下 =
  3. # tensor([[False, False, True, True, True, True, True, True, True, True],
  4. # [False, False, False, False, False, False, True, True, True, True]])

注意,我们必须对索引和idxs进行unsqueeze以进行>=运算。

最后,使用掩码来更新x

  1. x = y * mask + x * ~mask
英文:

First, create a index tensor on the second dimension of your tensor with

  1. second_dim_indices = torch.arange(x.shape[1])

and turn idxs into a tensor:

  1. idxs = torch.LongTensor(idxs)

Then, it is possible to compute a mask that's true when tensor index must be modified with :

  1. mask = second_dim_indices.unsqueeze(0) >= idxs.unsqueeze(1)
  2. # in your case =
  3. # tensor([[False, False, True, True, True, True, True, True, True, True],
  4. # [False, False, False, False, False, False, True, True, True, True]])

Note that we must unsqueeze indices and idxs to broadcast the >= operation.

Finally, use the mask to update x:

  1. x = y * mask + x * ~mask

huangapple
  • 本文由 发表于 2023年6月21日 23:55:56
  • 转载请务必保留本文链接:https://go.coder-hub.com/76525137.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定