使用JAX在大型二维数组上查找最大的n个值

huangapple go评论67阅读模式
英文:

Using JAX on large 2d array to find the largest n value

问题

我试图使用JAX来加速我的行和列比较/选择操作。我有一个NxN的二维数组,每个单元格都是一个数字。我尝试从每一行中获取最大的4个数字,然后将一个1放入具有相同索引的不同矩阵C中。这里有两个约束条件:每行和每列最多可以有4个1。

为了简化示例,这是一个示例:

array a = [[1,2,3,4,5],
[6,7,8,9,10],
[11,12,13,14,15],
[16,17,18,19,20],
[21,22,23,24,25]]
C:
[[0,1,1,1,1],
[0,1,1,1,1],
[0,1,1,1,1],
[0,1,1,1,1],
[1,0,0,0,0]]

我们不能选择最后一行的最后4个单元格,即使这4个单元格的值最高,所以我们只能选择第一个并将其放入C矩阵中。

对于一个不同的8x8二维数组和正确的C矩阵,这是我的尝试,实际矩阵是4000x4000,我的代码将花费10多分钟才能完成:

a_array = jnp.array([[0,3,4,0,0,12,19,22],
        [7,0,0,10,0,0,0,15],
        [12,0,0,15,16,19,0,31],
        [17,18,0,0,21,23,78,89],
        [22,2,78,0,0,1111,12,33],
        [123,0,122,10,14,0,50,60],
        [10,110,0,1231,0,110,0,61],
        [0,17,0,141,0,166,16,0]])

array_len = len(a_array)
c_matrix = jnp.zeros((array_len,array_len))
test_dict = {}
for i in range(array_len):
  test_dict[i] = 0
for i in range(array_len):
  for j in jnp.flip(jnp.argsort(jnp.array(a_array[i]))[-4:]):
    if test_dict[int(j)] < 4:
      if a_array[i][int(j)] != 0:
        test_dict[int(j)] +=1
        c_matrix = c_matrix.at[i,int(j)].set(1)
    if test_dict[int(j)] == 4:
      a_array = a_array.at[:,int(j)].set(0)

C矩阵如下:

[[0. 0. 1. 0. 0. 1. 1. 1.]
 [1. 0. 0. 1. 0. 0. 0. 1.]
 [0. 0. 0. 1. 1. 1. 0. 1.]
 [0. 0. 0. 0. 1. 1. 1. 1.]
 [1. 0. 1. 0. 0. 1. 1. 0.]
 [1. 0. 1. 0. 1. 0. 1. 0.]
 [1. 1. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 1. 0. 0. 0. 0.]]

在这里,我首先使用字典跟踪每列中的1的数量,如果已经有4个1,就将2D数组中的该列更新为全零,因此,当jnp.argsort(...)尝试找到最高的4个单元格时,它不会考虑0。我还有一个条件检查,用于排除这种边缘情况:

jnp.argsort(jnp.array([0,0,0,0,0,0,12,0]))[-4:]

输出:

[4, 5, 7, 6]

感谢大家的帮助。

英文:

I'm trying to use JAX to speed up my row and column-wise comparison/selection, I have a 2d array NxN, and each cell is a number, I'm trying to get the highest 4 number out of a row, then put a 1 in a different matrix C with the same index. These are two constraints: Each row and column can have maximum of 4 1s.
For a simpler example:

array a = [[1,2,3,4,5],
[6,7,8,9,10],
[11,12,13,14,15],
[16,17,18,19,20],
[21,22,23,24,25]]
C:
[[0,1,1,1,1],
[0,1,1,1,1],
[0,1,1,1,1],
[0,1,1,1,1],
[1,0,0,0,0]]

We can not select the last 4 cells in the last row even that 4 cells' value are the highest, so we can only select the first one and put it in the C matrix

Here is what I tried with a different 2d array 8x8 and the correct C matrix, still an example, the real matrix is 4000x4000 and my code will take 10+ mins to complete

a_array = jnp.array([[0,3,4,0,0,12,19,22],
            [7,0,0,10,0,0,0,15],
            [12,0,0,15,16,19,0,31],
            [17,18,0,0,21,23,78,89],
            [22,2,78,0,0,1111,12,33],
            [123,0,122,10,14,0,50,60],
            [10,110,0,1231,0,110,0,61],
            [0,17,0,141,0,166,16,0]])

array_len = len(a_array)
c_matrix = jnp.zeros((array_len,array_len))
test_dict = {}
for i in range(array_len):
  test_dict[i] = 0
for i in range(array_len):
  for j in jnp.flip(jnp.argsort(jnp.array(a_array[i]))[-4:]):
    if test_dict[int(j)] &lt; 4:
      if a_array[i][int(j)] != 0:
        test_dict[int(j)] +=1
        c_matrix = c_matrix.at[i,int(j)].set(1)
    if test_dict[int(j)] == 4:
      a_array = a_array.at[:,int(j)].set(0)

and the C matrix is:

[[0. 0. 1. 0. 0. 1. 1. 1.]
 [1. 0. 0. 1. 0. 0. 0. 1.]
 [0. 0. 0. 1. 1. 1. 0. 1.]
 [0. 0. 0. 0. 1. 1. 1. 1.]
 [1. 0. 1. 0. 0. 1. 1. 0.]
 [1. 0. 1. 0. 1. 0. 1. 0.]
 [1. 1. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 1. 0. 0. 0. 0.]]

What I did here is first have a dict track of number of 1s in each column, if there are 4 1s already, update the 2d array with that column being all zeros, therefore, next for loop when jnp.argsort(...) trying to find the highest 4 cells, it wouldn't take the 0 into account, I also have a if check on cell == 0, to get rid of this edge case

jnp.argsort(jnp.array([0,0,0,0,0,0,12,0]))[-4:]

output:

[4, 5, 7, 6]

Thank you all in advance.

答案1

得分: 1

以下是翻译好的内容:

在使用JAX、NumPy或类似的Python数组库编写代码时,一个很好的经验法则是,如果你正在对数组值进行循环,结果将会很慢。相反,你应该尝试用本机向量化操作来表达你的逻辑。

在这种情况下,你无法将整个操作向量化,因为每列的数量约束意味着每一行的输出取决于所有前一行的输出。在这种情况下,lax.scan 是一个不错的选择。

考虑到这些因素,以下是我如何解决你的问题的方式:

import jax

def scan_fun(count, row):
  row = jnp.where(count >= 4, 0, row)
  _, indices = jax.lax.top_k(row, 4)
  c_row = jnp.zeros_like(row).at[indices].set(1)
  c_row = jnp.where(row == 0, 0, c_row)
  count += (c_row > 0)
  return count, c_row

_, c_matrix = jax.lax.scan(scan_fun, jnp.zeros_like(a_array[0]), a_array)
print(c_matrix)
# [[0 0 1 0 0 1 1 1]
#  [1 0 0 1 0 0 0 1]
#  [0 0 0 1 1 1 0 1]
#  [0 0 0 0 1 1 1 1]
#  [1 0 1 0 0 1 1 0]
#  [1 0 1 0 1 0 1 0]
#  [1 1 0 1 0 0 0 0]
#  [0 1 0 1 0 0 0 0]]

希望对你有所帮助!

英文:

A good rule of thumb when writing code with JAX, NumPy, or similar array libraries in Python is that if you're writing loops over array values, the result will be slow. Instead, you should try to express your logic in terms of native vectorized operations.

Here you can't vectorize the whole operation, because the number-per-column constraint means the output of each row depends on the output of all previous rows. In cases like this, lax.scan is a good option.

Here's how I would solve your problem with these things in mind:

import jax

def scan_fun(count, row):
  row = jnp.where(count &gt;= 4, 0, row)
  _, indices = jax.lax.top_k(row, 4)
  c_row = jnp.zeros_like(row).at[indices].set(1)
  c_row = jnp.where(row == 0, 0, c_row)
  count += (c_row &gt; 0)
  return count, c_row

_, c_matrix = jax.lax.scan(scan_fun, jnp.zeros_like(a_array[0]), a_array)
print(c_matrix)
# [[0 0 1 0 0 1 1 1]
#  [1 0 0 1 0 0 0 1]
#  [0 0 0 1 1 1 0 1]
#  [0 0 0 0 1 1 1 1]
#  [1 0 1 0 0 1 1 0]
#  [1 0 1 0 1 0 1 0]
#  [1 1 0 1 0 0 0 0]
#  [0 1 0 1 0 0 0 0]]

huangapple
  • 本文由 发表于 2023年5月25日 14:58:12
  • 转载请务必保留本文链接:https://go.coder-hub.com/76329622.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定