英文:
jax.numpy.delete assume_unique_indices unexpected keyword argument
问题
I can provide a translation of the code-related content without the code part:
根据文档,jnp.delete有一个关键字参数"assume_unique_indices",应该使该函数在确保索引数组是整数数组且保证包含唯一条目时与jit兼容。在下面是一个最小的可重现示例:
错误消息:
删除"assume_unique_indices"使其按预期工作。
英文:
I can not seem to get the assume_unique_indices from jax.numpy working. According to the documentation here, the jnp.delete has a keyword argument "assume_unique_indices" that is supposed to make this function jit compatible when we are sure that the index array is an integer array and is guaranteed to contain unique entries.
Here is an minimum reproducible example
import jax
arr = jnp.array([1, 2, 3, 4, 5])
idx = jnp.array([0, 2, 4])
print(jax.__version__)
# Delete elements at indices idx
out = jax.numpy.delete(arr, idx, assume_unique_indices=True)
print(out) # [2 4]
The error message
0.4.8
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-12-bf0277118922> in <cell line: 9>()
7
8 # Delete elements at indices idx
----> 9 out = jax.numpy.delete(arr, idx, assume_unique_indices=False)
10
11 print(out) # [2 4]
TypeError: delete() got an unexpected keyword argument 'assume_unique_indices'
Deleting the assume_unique_indices
made it work as expected.
答案1
得分: 0
"Ok, as it turns out, the 'assume_unique_indices' is only added rather recently, updating to jax version 0.4.10 did the trick."
英文:
Ok, as it turns out, the 'assume_unique_indices' is only added rather recently, updating to jax version 0.4.10 did the trick
答案2
得分: 0
assume_unique_indices
在 https://github.com/google/jax/pull/15671 中添加,是在 JAX 版本 0.4.8 发布后添加的。如果您升级到版本 0.4.9 或更新版本,您的代码应该可以工作。
英文:
assume_unique_indices
was added in https://github.com/google/jax/pull/15671, after JAX version 0.4.8 was released. If you update to version 0.4.9 or newer, your code should work.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论