在PyTorch中使用起始和结束索引切片1D张量。

huangapple go评论69阅读模式
英文:

Slice 1D tensor in PyTorch with tensors of start and end indexes

问题

我试图在PyTorch中从1D张量创建一组偶数切片的2D张量。假设我们有一个1D数据张量和索引张量如下:

data = torch.arange(10)
data
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

starts = torch.tensor([0, 3, 4, 1])
ends = starts + 2
starts
tensor([0, 3, 4, 1])
ends
tensor([2, 5, 6, 3])

我如何能够索引data张量,而不是循环遍历并使用每组索引进行切片,以实现如下结果:

dataSlices
tensor([[0, 1],
        [3, 4],
        [4, 5],
        [1, 2]])

我最初的明显想法是将startsends放在一个地方,就像处理单个索引一样,但它会出现错误:

data[starts:ends]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: only integer tensors of a single element can be converted to an index

我已经查看了文档的一些部分,但似乎找不到一种方法,我是否遗漏了一些明显的东西?

英文:

I am trying to create a 2D tensor of even slices from a 1D tensor in PyTorch. Say we have a 1D data tensor and tensors of indexes as:

&gt;&gt;&gt; data = torch.arange(10)
&gt;&gt;&gt; data
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
&gt;&gt;&gt; starts = torch.tensor([0, 3, 4, 1])
&gt;&gt;&gt; ends = starts + 2
&gt;&gt;&gt; starts
tensor([0, 3, 4, 1])
&gt;&gt;&gt; ends
tensor([2, 5, 6, 3])

How could I index the data tensor without looping over and slicing with each set of indexes to achieve a result as:


&gt;&gt;&gt; dataSlices
tensor([[0, 1],
        [3, 4],
        [4, 5],
        [1, 2]])

My first obvious thought is to just put the starts and ends as you would with individual indexes but it just errors out:

&gt;&gt;&gt; data[starts:ends]
Traceback (most recent call last):
  File &quot;&lt;stdin&gt;&quot;, line 1, in &lt;module&gt;
TypeError: only integer tensors of a single element can be converted to an index

I've looked through some parts of the documentation but can't seem to find a way, am I missing something obvious?

答案1

得分: 0

如果它是一个列表,zip会解决你的问题。

看起来你需要:
torch.transpose()

并使用@bachr在这个答案中提供的解决方案:
https://stackoverflow.com/a/60367265/3456886

英文:

If it were a list, zip would solve your problem

Looks like you need:
torch.transpose().

And use the solution from this answer by @bachr:
https://stackoverflow.com/a/60367265/3456886

答案2

得分: 0

EDIT:

自那以后,我找到了一种适合Python的方式来处理范围,而不使用列表推导!为此,您的结束点应该大一点,因为这种方法将使用Python范围,不包含范围的结束点。

indices = torch.stack((starts, ends), axis=1)
newtensor = torch.stack([data[slice(idx[0], idx[1])] for idx in indices])

OLD ANSWER:

您可以使用torch.take来完成这个任务。为了获得您期望的输出,您需要从结束索引中减去1,因为它接受确切的索引,而不是区间。 (或者您可以一开始就像这样生成结束索引)

indices = torch.stack((starts, ends-1), axis=1)
newtensor = torch.take(data, indices)

如果您想要获取真正的区间(根据您将索引命名为starts和ends的事实),这将是一个解决方案:

indices = torch.stack((starts, ends), axis=1)
rangeindices = [torch.arange(i[0], i[1]) for i in indices]
tensorindices = torch.stack(rangeindices).type(torch.LongTensor)
newtensor = torch.take(data, tensorindices)

但这会(可以理解地)导致与您期望的输出不同的张量。

英文:

EDIT:

Since then I found a pythonic way for the ranges, without the list comprehension! For this your ends should be bigger by one, as this method will take python ranges, wich does not contain the end of the range.

indices=torch.stack((starts,ends),axis=1)
newtensor=torch.stack([data[slice(idx[0], idx[1])] for idx in indices])

OLD ANSWER:

You can do this with torch.take. To get your desired output, you need to subtract 1 from your ends indices, as it takes exact indices, not intervals. (Alternatively you can generate ends like that in the first place)

indices=torch.stack((starts,ends-1),axis=1)
newtensor=torch.take(data,indices)

tensor([[0, 1],
        [3, 4],
        [4, 5],
        [1, 2]])

If you would want to take real intervals
(based on the fact that you named the indices starts and ends), this would be a solution for that:

indices=torch.stack((starts,ends),axis=1)
rangeindices=[torch.range(i[0],i[1]) for i in indices]
tensorindices=torch.stack(rangeindices).type(torch.LongTensor)
newtensor=torch.take(data,tensorindices)

tensor([[0, 1, 2],
        [3, 4, 5],
        [4, 5, 6],
        [1, 2, 3]])

But this would (understandably) result in a different tensor than your expected output.

huangapple
  • 本文由 发表于 2023年2月14日 20:24:44
  • 转载请务必保留本文链接:https://go.coder-hub.com/75447792.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定