How to slice jax arrays using jax tracer?
Question
I am trying to modify a code base to create a subarray from an existing array, using indices that are JAX tracers. When I pass these tracers directly as slice indices, I get the following error:
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Tracedwith, Tracedwith, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
What is a possible workaround or solution for this?
Answer 1
Score: 1
There are two main workarounds here that may be applicable depending on your problem: using static indices, or using dynamic_slice.

Quick background: one constraint of arrays used in JAX transformations like jit, vmap, etc. is that they must be statically shaped (see JAX Sharp Bits: Dynamic Shapes for some discussion of this).

With that in mind, a function like f below will always fail, because i and j are non-static variables and so the shape of the returned array cannot be known at compile time:

    from jax import jit

    @jit
    def f(x, i, j):
        return x[i:j]  # fails: i and j are traced, so the output shape is unknown

One workaround for this is to make i and j static arguments in jit, so that the shape of the returned array will be static:

    from functools import partial
    from jax import jit

    @partial(jit, static_argnames=['i', 'j'])
    def f(x, i, j):
        return x[i:j]  # i and j are compile-time constants here

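As a rough illustration of the static-argument approach, the sketch below calls such a function with plain Python integers. The function name slice_static and the sample array are made up for this example. One caveat worth keeping in mind: static arguments must be hashable Python values at the call site, so this only helps if the indices are not already tracers produced inside an enclosing transformation, and each new (i, j) pair is expected to trigger a fresh compilation.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit


# Hypothetical example function; i and j are treated as compile-time constants.
@partial(jit, static_argnames=["i", "j"])
def slice_static(x, i, j):
    return x[i:j]


x = jnp.arange(10)

# Plain Python ints are hashable, so they can be used as static arguments.
first = slice_static(x, 2, 5)   # output shape (3,)
second = slice_static(x, 0, 4)  # different (i, j): compiled again, shape (4,)
```

If the indices vary widely at runtime, the recompilation cost can outweigh the benefit of jit, which is one reason to prefer dynamic_slice when the slice length is fixed.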
Because of the static shape constraint, that's the only possible workaround for using jit when the length of the result depends on both i and j.
Another flavor of slicing problem that can lead to the same error might look like this:

    @jit
    def f(x, i):
        return x[i:i + 5]  # also fails: i is traced, even though the length is always 5

This will also result in a non-static index error. It could be fixed as above by marking i as static, but there is more information here: assuming that 0 <= i < len(x) - 5 holds, we know that the shape of the output array is (5,). This is a case where jax.lax.dynamic_slice is applicable (when you have a fixed slice size at a dynamic location):

    import jax
    from jax import jit

    @jit
    def f(x, i):
        # start index (i,) may be traced; slice size (5,) must be static
        return jax.lax.dynamic_slice(x, (i,), (5,))

Note that this will have different semantics than x[i:i + 5] in cases where the slice overruns the bounds of the array, but in most cases of interest it is equivalent.
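To make the difference in semantics concrete, here is a small sketch comparing the two on an out-of-bounds start index. The function name window and the sample array are invented for illustration. jax.lax.dynamic_slice clamps the start index so that the full-size slice fits inside the array, whereas NumPy-style slicing truncates the result.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper: fixed-size window of length 5 at a dynamic start index.
@jax.jit
def window(x, i):
    return jax.lax.dynamic_slice(x, (i,), (5,))


x = jnp.arange(10)

in_bounds = window(x, 2)  # same values as x[2:7]
clamped = window(x, 8)    # start is clamped to 5, so this is x[5:10]
truncated = x[8:13]       # outside jit: NumPy-style slicing truncates to 2 elements
```

The output of window always has shape (5,), which is what lets JAX compile it once for any value of i.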
There are other examples where neither of these two workarounds is applicable, for example when your program logic is predicated on creating dynamic-length arrays. In these cases, there is no easy workaround, and your best bet is to either (1) rewrite your algorithm in terms of static array shapes, perhaps using padded array representations, or (2) not use JAX.
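One way option (1) might look in practice is to keep the array at its full static shape and use a boolean mask to mark which elements belong to the "slice". The sketch below is only one possible padded representation, and the names masked_slice and masked_slice_sum are hypothetical. It works for computations that can ignore the padded positions, such as sums.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_slice(x, i, j):
    # Same shape as x; elements outside [i, j) are replaced with 0.
    idx = jnp.arange(x.shape[0])
    mask = (idx >= i) & (idx < j)
    return jnp.where(mask, x, 0)


@jax.jit
def masked_slice_sum(x, i, j):
    # Equivalent to x[i:j].sum(), but with a static output shape.
    return jnp.sum(masked_slice(x, i, j))


x = jnp.arange(10)
padded = masked_slice(x, 2, 5)
total = masked_slice_sum(x, 2, 5)
```

Here i and j can be traced values, because the shapes of all intermediate arrays depend only on x.shape.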