英文:
Creating a jax array using existing jax arrays of different lengths throws error
问题
我正在使用以下代码来使用jax数组将jax 2D数组的特定行设置为特定值:
zeros_array = jnp.zeros((3, 8))
value = jnp.array([1,2,3,4])
value_2 = jnp.array([1])
value_3 = jnp.array([1,2])
values = jnp.array([value,value_2,value_3])
zeros_array = zeros_array.at[0].set(values)
但是,我收到以下错误:
ValueError: All input arrays must have the same shape.
当我将jnp修改为np(numpy)时,错误消失了。是否有任何方法解决这个错误?我知道一个解决方法是使用at[0,1].set()
,at[0,2:n].set()
来设置2D数组中的每个单独的数组。
英文:
I am using the following code to set a particular row of a jax 2D array to a particular value using jax arrays:
zeros_array = jnp.zeros((3, 8))
value = jnp.array([1,2,3,4])
value_2 = jnp.array([1])
value_3 = jnp.array([1,2])
values = jnp.array([value,value_2,value_3])
zeros_array = zeros_array.at[0].set(values)
But, I am receiving the following error:
ValueError: All input arrays must have the same shape.
Upon modifying the jnp to np (numpy) the error disappears. Is there any way to resolve this error? I know one walk around this would be to set each of the separate arrays in the 2D array using at[0,1].set(), at[0,2:n].set().
答案1
得分: 1
你想要的是一个"不规则数组",目前在JAX中还没有办法实现这个。在旧版本的NumPy中,可以通过返回一个dtype为object
的数组来实现,但在新版本的NumPy中,这会导致错误,因为object
数组通常不方便且效率低下(例如,如果更新存储在对象数组中,就没有有效地执行最后一行中的索引更新操作的方法)。
根据您的用例,无论是在JAX还是NumPy中,都有几种可以使用的解决方法,包括将数组的行存储为列表,或者使用填充的2D数组表示。
我还要注意的是,JAX团队正在探索对不规则数组的本机支持(请参见例如 https://github.com/google/jax/pull/16541),但它距离一般可用还有一段距离。
英文:
What you have in mind is a "ragged array", and no, there is not currently any way to do this in JAX. In older versions of NumPy, this will work by returning an array of dtype object
, but in newer versions of NumPy this results in an error because object
arrays are generally inconvenient and inefficient to work with (for example, there's no way to efficiently do the equivalent of the index update operation in your last line if the updates are stored in an object array).
Depending on your use-case, there are several workarounds for this you might use in both JAX and NumPy, including storing the rows of your array as a list, or using a padded 2D array representation.
I'll note also that the JAX team is exploring native support for ragged arrays (see e.g. https://github.com/google/jax/pull/16541) but it's still fairly far from being generally useful.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论