返回一个列表的列表 Pytorch

huangapple go评论48阅读模式
英文:

Index matrix but return a list of lists Pytorch

问题

I have a 2-dimensional tensor and I would like to index it so that the result is a list of lists. For example:

R = torch.tensor([[1,2,3], [4,5,6]])
mask = torch.tensor([[1,0,0],[1,1,1]], dtype=torch.bool)
output = R[mask]

This makes output as tensor([1, 4, 5, 6]). However, I would like to have [[1], [4,5,6]] or [tensor(1), tensor([4,5,6])].

I know that it could be done with a loop and using .append(). However, I would like to avoid the use of any loop to make it faster if R and mask are very big.

Is there any way to do that in Python without any loop?

英文:

I have a 2-dimensional tensor and I would like to index it so that the result is a list of lists. For example:

R = torch.tensor([[1,2,3], [4,5,6]])
mask = torch.tensor([[1,0,0],[1,1,1]], dtype=torch.bool)
output = R[mask]

This makes output as tensor([1, 4, 5, 6]). However, I would like to have [[1], [4,5,6]] or [tensor(1), tensor([4,5,6])].

I now that it could be done with a loop and ussing .append(). However, I would like to avoid the use of any loop to make it faster if R and mask are very big.

Is there any way to do that in Python without any loop?

答案1

得分: 1

你可以尝试使用PyTorch的索引和广播

output = R[mask].split([mask[i].sum() for i in range(mask.shape[0])])

这会基于掩码中每行的`True`值拆分结果张量,得到一个张量列表,其中每个元素对应原始张量的一行。
`.split()`方法返回一个张量元组,因此你可以将元组的每个元素转换为列表或张量,如下所示:

output = [tensor.tolist() for tensor in output]

这将给出一个列表,其中每个子列表对应原始张量的一行,其中`False`值已移除。
英文:

You can try using PyTorch's indexing and broadcasting:

output = R[mask].split([mask[i].sum() for i in range(mask.shape[0])])

This splits the resulting tensor based on True values in each row of the mask, which gives a list of tensors where each element corresponds to a row of the original tensor.
The .split() method returns a tuple of tensors, so you can cast each element of the tuple to a list or a tensor as follows:

output = [tensor.tolist() for tensor in output]

This will give a list of lists, wherein each sublist corresponds to a row of the original tensor with the False values removed.

huangapple
  • 本文由 发表于 2023年2月14日 20:23:14
  • 转载请务必保留本文链接:https://go.coder-hub.com/75447775.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定