How to vectorize slicing operations on torch tensors?

Question

I have several quantities computed with list comprehensions, and I want to turn them into torch tensors. So far I have:

import torch
n = 10
y = torch.rand(n ** 2, requires_grad=True)
one_node_per_position = torch.FloatTensor([sum(y[k:k + n]) - 1 for k in range(0, n ** 2, n)])
one_node_per_point = torch.FloatTensor([sum(y[j::n]) - 1 for j in range(n)])
connectivity = torch.FloatTensor([sum(y[k:k + n]) - sum(y[k - n:k]) for k in range(n, n ** 2, n)])

But this is clearly not ideal. How can I rewrite it to take advantage of vectorization?

Answer 1

Score: 1

You can use torch.as_strided to create strided views of the original tensor without a Python loop, and then sum over the last dimension of each view:

import torch

n = 10
y = torch.rand(n**2, requires_grad=True)

# Vectorized version
vec_one_node_per_position = torch.as_strided(y, (n, n), (n, 1)).sum(axis=-1) - 1
vec_one_node_per_point = torch.as_strided(y, (n, n), (1, n)).sum(axis=-1) - 1
vec_connectivity = (
    torch.as_strided(y, (n - 1, n), (n, 1), n).sum(axis=-1) 
    - torch.as_strided(y, (n - 1, n), (n, 1)).sum(axis=-1)
)


# Ensure consistency with comprehension-based version
one_node_per_position = torch.FloatTensor([sum(y[k : k + n]) - 1 for k in range(0, n**2, n)])
one_node_per_point = torch.FloatTensor([sum(y[j::n]) - 1 for j in range(n)])
connectivity = torch.FloatTensor([sum(y[k : k + n]) - sum(y[k - n : k]) for k in range(n, n**2, n)])


assert torch.isclose(vec_one_node_per_position, one_node_per_position).all()
assert torch.isclose(vec_one_node_per_point, one_node_per_point).all()
assert torch.isclose(vec_connectivity, connectivity).all()
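
As an alternative sketch (not part of the original answer), the same three quantities can likely be obtained with a plain view of y as an (n, n) grid, since the slices in the question are exactly the rows and columns of that grid. The names grid, row_sums and col_sums are illustrative assumptions, and n = 4 is chosen only to keep the example small.

```python
import torch

n = 4  # small size chosen only for illustration
y = torch.rand(n ** 2, requires_grad=True)

# View the flat tensor as an (n, n) grid: row i holds y[i*n : (i+1)*n].
grid = y.view(n, n)

row_sums = grid.sum(dim=1)  # sums of consecutive blocks of length n
col_sums = grid.sum(dim=0)  # sums of the strided slices y[j::n]

one_node_per_position = row_sums - 1
one_node_per_point = col_sums - 1
# Each block sum minus the previous block sum.
connectivity = row_sums[1:] - row_sums[:-1]

# Cross-check: the grid and its transpose match the as_strided views above.
assert torch.equal(grid, torch.as_strided(y, (n, n), (n, 1)))
assert torch.equal(grid.t(), torch.as_strided(y, (n, n), (1, n)))

# The results stay attached to the autograd graph, so gradients reach y.
connectivity.sum().backward()
```

Computing the row sums once and differencing them avoids building two overlapping views for connectivity. A view also cannot read outside the tensor's storage, whereas as_strided trusts whatever sizes, strides and offset it is given.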

huangapple
  • Published on February 19, 2023 at 20:58:08
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75500307.html