How to use an array to index the last dim of a second array

Question

Assume that X.shape = (h, w, d) and that Y.shape = (h, w) contains values (indices) in range(d).

How can I take the elements in X (from the last dim) using the indices in Y?

That is, I would like to do something like X[Y] that will return an h x w array where Y was used as the indices in the last dim d.

Answer 1

Score: 3

I think you are looking for numpy.take_along_axis:

import numpy as np
rng = np.random.default_rng()

h, w, d = 2, 3, 4
x = rng.random((h, w, d))
y = rng.integers(0, d, (h, w))

# y[..., None] has shape (h, w, 1), matching x.ndim; [..., 0] drops that axis again
x, y, np.take_along_axis(x, y[..., None], axis=-1)[..., 0]

Output:

(array([[[0.51705108, 0.68891581, 0.84475703, 0.77938839],
         [0.02115493, 0.47689898, 0.19786926, 0.73959225],
         [0.40821923, 0.0119006 , 0.89595898, 0.81798467]],

        [[0.60350791, 0.11501983, 0.15932539, 0.35923036],
         [0.27939872, 0.13691148, 0.47528086, 0.71320657],
         [0.98294212, 0.75039413, 0.06087527, 0.68233282]]]),
 array([[0, 2, 0],
        [3, 0, 0]]),
 array([[0.51705108, 0.19786926, 0.40821923],
        [0.35923036, 0.27939872, 0.98294212]]))
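
To make explicit what the call above computes, here is a small sketch that checks it against a plain double loop. The fixed seed and the names `result` and `expected` are my own additions for illustration, not part of the answer.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed is an assumption, only for reproducibility

h, w, d = 2, 3, 4
x = rng.random((h, w, d))
y = rng.integers(0, d, (h, w))

# take_along_axis needs indices with the same number of dimensions as x,
# so y is expanded to shape (h, w, 1); [..., 0] removes that axis again.
result = np.take_along_axis(x, y[..., None], axis=-1)[..., 0]

# Reference: pick x[i, j, y[i, j]] element by element.
expected = np.empty((h, w))
for i in range(h):
    for j in range(w):
        expected[i, j] = x[i, j, y[i, j]]

print(result.shape)                      # (2, 3)
print(np.array_equal(result, expected))  # True
```

The loop is only there as a readable specification; for real array sizes the vectorized call is the one to use.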

Comparison with Ulises' answer:

assert np.array_equal(np.take_along_axis(x, y[..., None], axis=-1)[..., 0], x[np.eye(d)[y].astype(bool)].reshape(y.shape))
%timeit np.take_along_axis(x, y[..., None], axis=-1)[..., 0]
%timeit x[np.eye(d)[y].astype(bool)].reshape(y.shape)

Output:

9.42 µs ± 91.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
55.8 µs ± 463 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Answer 2

Score: 2

Edit


The shortest way to select the elements is to move the last axis of x to the front and then use np.choose with y as the selector:

np.choose(y, x.transpose(2,0,1))
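
As a quick sanity check, the sketch below compares the np.choose route with np.take_along_axis on small deterministic inputs (the inputs and variable names are my own choice). Note that np.choose limits how many choice arrays it accepts (32 in NumPy 1.x, to my knowledge), so this route only suits a small d.

```python
import numpy as np

h, w, d = 3, 3, 4
x = np.arange(h * w * d).reshape(h, w, d)
y = np.arange(h * w).reshape(h, w) % d  # deterministic indices in range(d)

# Move the last axis to the front: choices[k] is the (h, w) slice x[:, :, k].
choices = x.transpose(2, 0, 1)
via_choose = np.choose(y, choices)

# Same selection with take_along_axis, for comparison.
via_take = np.take_along_axis(x, y[..., None], axis=-1)[..., 0]

print(np.array_equal(via_choose, via_take))  # True
```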


Old solution


Supposing you have x and y, you could build a one-hot (categorical) array from y and use it as a boolean mask to select from x:

Example:

import numpy as np

# create arrays
h, w, d = 3, 3, 4
x = np.arange(h * w * d).reshape(h, w, d)
y = np.random.randint(d, size=(h, w))

# categorical (one-hot) y, shape (h, w, d)
cat_y = np.eye(d)[y]

# select elements from x with the boolean mask
x_sel = x[cat_y.astype(bool)]

# reshape to the shape of y
x_out = x_sel.reshape(y.shape)

Or in one line:

xx = x[np.eye(d)[y].astype(bool)].reshape(y.shape)
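
The one-liner works because boolean-mask indexing returns the selected elements in row-major order and the mask holds exactly one True per (i, j) position, so the flat result can be reshaped back to y.shape. A minimal sketch with small hand-picked inputs (the values and the names `mask` and `flat` are assumptions for illustration):

```python
import numpy as np

h, w, d = 2, 2, 3
x = np.arange(h * w * d).reshape(h, w, d)
y = np.array([[2, 0],
              [1, 1]])

# One-hot mask of shape (h, w, d): exactly one True along the last axis.
mask = np.eye(d, dtype=bool)[y]
print(mask.sum(axis=-1))       # every entry is 1

# Boolean indexing flattens the selection in row-major (i, j) order.
flat = x[mask]
print(flat.tolist())           # [2, 3, 7, 10]

print(flat.reshape(y.shape))   # same values arranged as an (h, w) array
```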

Answer 3

Score: 1

X[np.arange(h)[:,None], np.arange(w), Y]

The idea is to index the first two dimensions with arrays that broadcast against Y; the three index arrays have shapes (h, 1), (w,) and (h, w). That's the general principle for all advanced indexing, and it's what we had to do before take_along_axis was added.

np.ix_, np.meshgrid and np.ogrid could also be used to make these arrays.
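
The helpers mentioned above can be sketched as follows. This is an illustrative comparison on small deterministic inputs of my own choosing, with each variant checked against np.take_along_axis; the variable names are not from the answer.

```python
import numpy as np

h, w, d = 2, 3, 4
X = np.arange(h * w * d).reshape(h, w, d)
Y = np.arange(h * w).reshape(h, w) % d  # deterministic indices in range(d)

# Manual version: index shapes (h, 1), (w,), (h, w) broadcast to (h, w).
manual = X[np.arange(h)[:, None], np.arange(w), Y]

# np.ix_ returns open-mesh index arrays of shapes (h, 1) and (1, w).
I, J = np.ix_(np.arange(h), np.arange(w))
via_ix = X[I, J, Y]

# np.ogrid builds the same open mesh from slice notation.
I, J = np.ogrid[:h, :w]
via_ogrid = X[I, J, Y]

# np.meshgrid with indexing="ij" returns full (h, w) index arrays.
I, J = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
via_meshgrid = X[I, J, Y]

reference = np.take_along_axis(X, Y[..., None], axis=-1)[..., 0]
for name, result in [("manual", manual), ("ix_", via_ix),
                     ("ogrid", via_ogrid), ("meshgrid", via_meshgrid)]:
    print(name, np.array_equal(result, reference))  # True for each variant
```

The open-mesh forms (np.ix_, np.ogrid) avoid materializing full (h, w) index arrays, which is the main practical difference from np.meshgrid here.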

huangapple
  • Posted on July 14, 2023, 00:29:36
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76681556.html