创建一个使用不同长度的现有 Jax 数组的 Jax 数组会引发错误。

huangapple go评论70阅读模式
英文:

Creating a jax array using existing jax arrays of different lengths throws error

问题

我正在使用以下代码来使用jax数组将jax 2D数组的特定行设置为特定值:

zeros_array = jnp.zeros((3, 8))
value = jnp.array([1,2,3,4])
value_2 = jnp.array([1])
value_3 = jnp.array([1,2])
values = jnp.array([value,value_2,value_3])
zeros_array = zeros_array.at[0].set(values)

但是,我收到以下错误:

ValueError: All input arrays must have the same shape.

当我将jnp修改为np(numpy)时,错误消失了。是否有任何方法解决这个错误?我知道一个解决方法是使用at[0,1].set()at[0,2:n].set()来设置2D数组中的每个单独的数组。

英文:

I am using the following code to set a particular row of a jax 2D array to a particular value using jax arrays:

zeros_array = jnp.zeros((3, 8))
value = jnp.array([1,2,3,4])
value_2 = jnp.array([1])
value_3 = jnp.array([1,2])
values = jnp.array([value,value_2,value_3])
zeros_array = zeros_array.at[0].set(values)

But, I am receiving the following error:

ValueError: All input arrays must have the same shape.

Upon modifying the jnp to np (numpy) the error disappears. Is there any way to resolve this error? I know one walk around this would be to set each of the separate arrays in the 2D array using at[0,1].set(), at[0,2:n].set().

答案1

得分: 1

你想要的是一个"不规则数组",目前在JAX中还没有办法实现这个。在旧版本的NumPy中,可以通过返回一个dtype为object的数组来实现,但在新版本的NumPy中,这会导致错误,因为object数组通常不方便且效率低下(例如,如果更新存储在对象数组中,就没有有效地执行最后一行中的索引更新操作的方法)。

根据您的用例,无论是在JAX还是NumPy中,都有几种可以使用的解决方法,包括将数组的行存储为列表,或者使用填充的2D数组表示。

我还要注意的是,JAX团队正在探索对不规则数组的本机支持(请参见例如 https://github.com/google/jax/pull/16541),但它距离一般可用还有一段距离。

英文:

What you have in mind is a "ragged array", and no, there is not currently any way to do this in JAX. In older versions of NumPy, this will work by returning an array of dtype object, but in newer versions of NumPy this results in an error because object arrays are generally inconvenient and inefficient to work with (for example, there's no way to efficiently do the equivalent of the index update operation in your last line if the updates are stored in an object array).

Depending on your use-case, there are several workarounds for this you might use in both JAX and NumPy, including storing the rows of your array as a list, or using a padded 2D array representation.

I'll note also that the JAX team is exploring native support for ragged arrays (see e.g. https://github.com/google/jax/pull/16541) but it's still fairly far from being generally useful.

huangapple
  • 本文由 发表于 2023年7月3日 05:04:27
  • 转载请务必保留本文链接:https://go.coder-hub.com/76600803.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定