英文:
Index matrix but return a list of lists Pytorch
问题
I have a 2-dimensional tensor and I would like to index it so that the result is a list of lists. For example:
R = torch.tensor([[1,2,3], [4,5,6]])
mask = torch.tensor([[1,0,0],[1,1,1]], dtype=torch.bool)
output = R[mask]
This makes output as tensor([1, 4, 5, 6]). However, I would like to have [[1], [4,5,6]] or [tensor(1), tensor([4,5,6])].
I know that it could be done with a loop and using .append(). However, I would like to avoid the use of any loop to make it faster if R and mask are very big.
Is there any way to do that in Python without any loop?
英文:
I have a 2-dimensional tensor and I would like to index it so that the result is a list of lists. For example:
R = torch.tensor([[1,2,3], [4,5,6]])
mask = torch.tensor([[1,0,0],[1,1,1]], dtype=torch.bool)
output = R[mask]
This makes output as tensor([1, 4, 5, 6]). However, I would like to have [[1], [4,5,6]] or [tensor(1), tensor([4,5,6])].
I now that it could be done with a loop and ussing .append(). However, I would like to avoid the use of any loop to make it faster if R and mask are very big.
Is there any way to do that in Python without any loop?
答案1
得分: 1
你可以尝试使用PyTorch的索引和广播:
output = R[mask].split([mask[i].sum() for i in range(mask.shape[0])])
这会基于掩码中每行的`True`值拆分结果张量,得到一个张量列表,其中每个元素对应原始张量的一行。
`.split()`方法返回一个张量元组,因此你可以将元组的每个元素转换为列表或张量,如下所示:
output = [tensor.tolist() for tensor in output]
这将给出一个列表,其中每个子列表对应原始张量的一行,其中`False`值已移除。
英文:
You can try using PyTorch's indexing and broadcasting:
output = R[mask].split([mask[i].sum() for i in range(mask.shape[0])])
This splits the resulting tensor based on True values in each row of the mask, which gives a list of tensors where each element corresponds to a row of the original tensor.
The .split() method returns a tuple of tensors, so you can cast each element of the tuple to a list or a tensor as follows:
output = [tensor.tolist() for tensor in output]
This will give a list of lists, wherein each sublist corresponds to a row of the original tensor with the False values removed.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。


评论