英文:
Index matrix but return a list of lists Pytorch
问题
I have a 2-dimensional tensor and I would like to index it so that the result is a list of lists. For example:
R = torch.tensor([[1,2,3], [4,5,6]])
mask = torch.tensor([[1,0,0],[1,1,1]], dtype=torch.bool)
output = R[mask]
This makes output
as tensor([1, 4, 5, 6])
. However, I would like to have [[1], [4,5,6]]
or [tensor(1), tensor([4,5,6])]
.
I know that it could be done with a loop and using .append()
. However, I would like to avoid the use of any loop to make it faster if R
and mask
are very big.
Is there any way to do that in Python without any loop?
英文:
I have a 2-dimensional tensor and I would like to index it so that the result is a list of lists. For example:
R = torch.tensor([[1,2,3], [4,5,6]])
mask = torch.tensor([[1,0,0],[1,1,1]], dtype=torch.bool)
output = R[mask]
This makes output
as tensor([1, 4, 5, 6])
. However, I would like to have [[1], [4,5,6]]
or [tensor(1), tensor([4,5,6])]
.
I now that it could be done with a loop and ussing .append()
. However, I would like to avoid the use of any loop to make it faster if R
and mask
are very big.
Is there any way to do that in Python without any loop?
答案1
得分: 1
你可以尝试使用PyTorch的索引和广播:
output = R[mask].split([mask[i].sum() for i in range(mask.shape[0])])
这会基于掩码中每行的`True`值拆分结果张量,得到一个张量列表,其中每个元素对应原始张量的一行。
`.split()`方法返回一个张量元组,因此你可以将元组的每个元素转换为列表或张量,如下所示:
output = [tensor.tolist() for tensor in output]
这将给出一个列表,其中每个子列表对应原始张量的一行,其中`False`值已移除。
英文:
You can try using PyTorch's indexing and broadcasting:
output = R[mask].split([mask[i].sum() for i in range(mask.shape[0])])
This splits the resulting tensor based on True
values in each row of the mask, which gives a list of tensors where each element corresponds to a row of the original tensor.
The .split()
method returns a tuple of tensors, so you can cast each element of the tuple to a list or a tensor as follows:
output = [tensor.tolist() for tensor in output]
This will give a list of lists, wherein each sublist corresponds to a row of the original tensor with the False
values removed.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论