Striding in numpy and pytorch, How force writing to an array or "input tensor and the written-to tensor refer to a single memory location"?

huangapple go评论66阅读模式
英文:

Striding in numpy and pytorch, How force writing to an array or "input tensor and the written-to tensor refer to a single memory location"?

问题

使用numpy数组和pytorch张量进行实验时,我发现了一个与依赖于先前时间步的时间系列相关的“黑魔法”差异。

类似于numpy的库的一个关键特性是它们通常不会复制数据,而是在不同的视图中操作数据。在numpy中执行这种操作的最有趣的工具之一是使用stride_tricks的函数。

作为一个更简单的问题,我试图创建一个“原位累积求和”,但我发现numpy不会出错,而且与pytorch的看似等效的代码不同。我实际上不需要一个原位累积求和的实现,而是需要进行原位广播操作,以使用广播计算的数据。

所以我的问题是:我如何强制pytorch允许这种求和操作?如何使numpy不对我的数据进行复制并对其求和?

import numpy as np
# array = np.arange(6.)
array = np.arange(6.).reshape(2,3) + 3
# array([[3., 4., 5.],
#       [6., 7., 8.]])
s = np.lib.stride_tricks.as_strided(array, shape=(1,3)) # array([[3., 4., 5.]])
array += s
array
# array([[ 6.,  8., 10.],
#        [ 9., 11., 13.]])
import torch
t = torch.arange(6.).reshape(2,3) + 3
# tensor([[3., 4., 5.],
        # [6., 7., 8.]])
s = t.as_strided(size=(1,3),stride=(1,1)) # tensor([[3., 4., 5.]])
t += s
t
# RuntimeError                              Traceback (most recent call last)
#       3 # tensor([[3., 4., 5.],
#       4         # [6., 7., 8.]])
#       5 s = t.as_strided(size=(1,3),stride=(1,1)) # tensor([[3., 4., 5.]])
# ----> 6 t += s
#       7 t

# RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.

期望结果:

#对第一行进行原位求和
#用原位修改的第一行对第二行进行求和
array([[6., 8., 10.],
      [12., 15., 18.]])
英文:

Experimenting with both numpy arrays and pytorch tensors I have came across a difference in an attempt to "do black magic" related to time series that rely on previous time steps

A key feature of the numpy-esque libraries is that they often don't copy data but manipulate it in different views. one of the most interesting tools to do that in numpy is using stride_tricks's functions.

As a simpler problem I was trying to make an "inplace cumsum" but i stumbled upon the fact that numpy doesn't error out and a seemingly equivalent code of pytorch.
I don't really need an in place cumsum implementation but rather inplace broadcasting operation that will use inbroadcast-calculated data.

So my questions are; How can I force pytorch to allow the sum?; How can I make numpy not make a copy of my data and sum over it?

import numpy as np
# array = np.arange(6.)
array = np.arange(6.).reshape(2,3) + 3
# array([[3., 4., 5.],
#       [6., 7., 8.]])
s = np.lib.stride_tricks.as_strided(array, shape=(1,3)) # array([[3., 4., 5.]])
array += s
array
# array([[ 6.,  8., 10.],
#        [ 9., 11., 13.]])
import torch
t = torch.arange(6.).reshape(2,3) + 3
# tensor([[3., 4., 5.],
        # [6., 7., 8.]])
s = t.as_strided(size=(1,3),stride=(1,1)) # tensor([[3., 4., 5.]])
t += s
t
# RuntimeError                              Traceback (most recent call last)
#       3 # tensor([[3., 4., 5.],
#       4         # [6., 7., 8.]])
#       5 s = t.as_strided(size=(1,3),stride=(1,1)) # tensor([[3., 4., 5.]])
# ----> 6 t += s
#       7 t

# RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.

Desired:

#sum the first row to itself inplace
#sum the 2nd row with the inplace modified 1st row
array([[6., 8., 10.],
      [12., 15., 18.]])

答案1

得分: 1

add.at 主要用于处理重复的索引。

虽然数字匹配,但我不确定它是否正在执行你的累积求和。

英文:
In [129]: arr=np.arange(3,9).reshape(2,3);
In [130]: s=arr[None,0]  # or the strided

In [131]: np.add.at(arr,((0,1,1),slice(None)),s)

In [132]: arr
Out[132]: 
array([[ 6,  8, 10],
       [12, 15, 18]])

In [133]: s
Out[133]: array([[ 6,  8, 10]])

add.at is most used to handled duplicated indices.

While the numbers match, I'm not sure it's doing your cumsum.

huangapple
  • 本文由 发表于 2023年6月6日 10:02:34
  • 转载请务必保留本文链接:https://go.coder-hub.com/76410994.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定