从PyTorch张量中删除行(使用pytorch中的drop方法)

huangapple go评论89阅读模式
英文:

Delete rows from values from a torch tensor (drop method in pytorch)

问题

让我们假设我有一个PyTorch张量

import torch
x = torch.tensor([
    [1,2,3,4],
    [5,6,7,8],
    [9,10,11,12]
    ])

并且我想删除包含值[5,6,7,8]的行。我已经看到了这个答案(通过索引解决问题),这个答案(通过掩码解决问题),这个答案这个答案(通过知道索引删除行)。

在我的情况下,我知道要删除的张量的值,但不知道索引,而且张量的每一列都应该具有相同的值。

我可以尝试在这个问题中执行掩码操作,然后像这样索引行,就像这里所示:

ind = torch.nonzero(torch.all(x==torch.tensor([5,6,7,8]), dim=0))
x = torch.cat((x[:ind],x[ind+1:]))

这是有效的,但我想要一个更干净的解决方案,而不是分割张量然后再连接它,类似于pandas数据框中的drop()方法。

英文:

Let's say I have a pytorch tensor

import torch
x = torch.tensor([
    [1,2,3,4],
    [5,6,7,8],
    [9,10,11,12]
    ])

And I want to delete the row with values [5,6,7,8]. I have seen this answer (which solves the problem by indexing), this one (which solves the problem by masking), this one and this one (deleting rows knowing the index).

In my case, I know the values of the tensor I want to delete, but not the index, and the values should be the same in every column of the tensor.

I could try doing the masking in this question and then indexing the rows as shown here, something like this:

ind = torch.nonzero(torch.all(x==torch.tensor([5,6,7,8]), dim=0))
x = torch.cat((x[:ind],x[ind+1:]))

That works, but I'd like a cleaner solution than splitting the tensor and concatenating it again. Something similar to the drop() method in pandas dataframes.

答案1

得分: 0

你可以使用torch.all~(位的反转)的组合来排除与给定条件不匹配的列。这是一行代码,而不需要拆分张量。

import torch

x = torch.tensor([[1, 2, 3 ,4],
                  [5, 6, 7 ,8],
                  [9, 10, 11 ,12],])

x = x[~torch.all(x == torch.tensor([5,6,7,8]), dim=1)]

结果张量如下:

tensor([[ 1,  2,  3,  4],
        [ 9, 10, 11, 12]])
英文:

You can use the torch.all with combination of ~ NOT (inversion of bits)to exclude the column(s) that does match with the given one. It is one line code without splitting the tensor.

import torch

x = torch.tensor([[1, 2, 3 ,4],
                  [5, 6, 7 ,8],
                  [9, 10, 11 ,12],])

x = x[~torch.all(x == torch.tensor([5,6,7,8]), dim=1)]

The resulting tensor is as follows;

tensor([[ 1,  2,  3,  4],
        [ 9, 10, 11, 12]])

huangapple
  • 本文由 发表于 2023年6月22日 18:01:45
  • 转载请务必保留本文链接:https://go.coder-hub.com/76530740.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定