英文:
JAX: unable to jnp.where with known sizes inside a pmap
问题
I wanted to do a pmap of a given function, with 2D arrays that might (or might not) contain nan values. That function must then apply some operations to the finite values that exist in each row (toy examples at the end of the post).
我想对给定的函数进行pmap,使用可能(或可能不)包含nan值的2D数组。然后该函数必须对每一行中存在的有限值应用一些操作(帖子末尾有玩具示例)。
I know how many points (per row) contain NaNs, even before I /jax.jit/ anything. Thus, I should be able to:
我知道每行中包含多少个NaN点,甚至在我执行/jax.jit/之前。因此,我应该能够:
import jax.numpy as jnp
inds = jnp.where(jnp.isfinite(line), size= Finite_points_number)
但是我无法将元素的大小传递到pmap函数中。
但是,我无法将元素的大小传递到pmap函数中。
I have tried to:
i) pmap over over the list with the number of good points per row:
我尝试过:
i) 在每行的有效点数列表上使用pmap:
data_array = jnp.array([
[1,2,3,4],
[4,5,6, jnp.nan]
]
)
sizes = jnp.asarray((4, 3)) # Number of valid points per row
def jitt_function(line, N):
"""
Over-simplified function to showcase the problem
"""
inds = jnp.where(jnp.isfinite(line), size=N)
return jnp.sum(line[inds])
pmap_func = jax.pmap(jitt_function,
in_axes=(0, 0)
)
pmap_func(data_array, sizes)
and it fails with
并且失败了
> The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. The error occurred while tracing the function jitt_function at [...] for pmap. This concrete value was not available in Python because it depends on the value of the argument 'N'.
> jnp.nonzero的size参数必须在JAX变换中静态指定,以使用jnp.nonzero。在pmap中跟踪函数jitt_function时出现错误。这个具体值在Python中不可用,因为它取决于参数'N'的值。
ii) I have also tried to turn the number of points (N) into a static argument:
ii) 我还尝试将点数(N)转换为静态参数:
jitt_function = jax.jit(jitt_function, static_argnames=("N",))
pmap_func = jax.pmap(jitt_function,
in_axes=(0, 0)
)
> ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function jitt_function is non-hashable.
> ValueError:不支持不可哈希的静态参数,因为这可能导致意外的缓存未命中。函数jitt_function的类型为<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>的静态参数(索引1)是不可哈希的。
Even if I managed to transform this into a static argument, I would still need to "know" the line number, so that I could access the correct number of good points.
即使我设法将其转换为静态参数,我仍然需要“知道”行号,以便我可以访问正确的有效点数。
Question: Is there any way for me to do this within jax?
**问题:**在JAX中是否有任何方法可以做到这一点?
英文:
I wanted to do a pmap of a given function, with 2D arrays that might (or might not) contain nan values. That function must then apply some operations to the finite values that exist in each row (toy examples at the end of the post).
I know how many points (per row) contain NaNs, even before I /jax.jit/ anything. Thus, I should be able to:
import jax.numpy as jnp
inds = jnp.where(jnp.isfinite(line), size= Finite_points_number)
but I am not able to pass the size of the elements into the pmap-ed function.
I have tried to:
i) pmap over over the list with the number of good points per row:
data_array = jnp.array([
[1,2,3,4],
[4,5,6, jnp.nan]
]
)
sizes = jnp.asarray((4, 3)) # Number of valid points per row
def jitt_function(line, N):
"""
Over-simplified function to showcase the problem
"""
inds = jnp.where(jnp.isfinite(line), size=N)
return jnp.sum(line[inds])
pmap_func = jax.pmap(jitt_function,
in_axes=(0, 0)
)
pmap_func(data_array, sizes)
and it fails with
> The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. The error occurred while tracing the function jitt_function at [...] for pmap. This concrete value was not available in Python because it depends on the value of the argument 'N'.
ii) I have also tried to turn the number of points (N) into a static argument:
jitt_function = jax.jit(jitt_function, static_argnames=("N",))
pmap_func = jax.pmap(jitt_function,
in_axes=(0, 0)
)
> ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function jitt_function is non-hashable.
Even if I managed to transform this into a static argument, I would still need to "know" the line number, so that I could access the correct number of good points.
Question: Is there any way for me to do this within jax?
答案1
得分: 0
以下是翻译好的部分:
如果每个pmap
批次具有不同数量的点,则该数量不是固定的。您的预期操作的结果将是一个不规则数组(即,行具有不同数量的元素的二维数组),不支持在JAX中使用不规则数组。
如果实际上每个批次具有固定数量的元素,意思是每个批次中的数量相等,那么您可以使用jnp.where
的size
参数执行此计算。可能会看起来像这样:
from functools import partial
def jitt_function(line, N):
"""
用于展示问题的过度简化的函数
"""
inds = jnp.where(jnp.isfinite(line), size=N, fill_value=0)
return jnp.sum(line[inds])
pmap_func = jax.pmap(partial(jitt_function, N=4))
pmap_func(data_array)
如果每个批次中的条目少于指定数量,那么一个选项是指定jnp.where
的fill_value
参数来填充输出。在这种情况下,因为您沿每个维度求和,填充值为零会返回预期的结果。
英文:
If you have a different number of points per pmap
batch, then the number is not static. The result of your intended operation would be a ragged array (i.e. a 2D array whose rows have differing numbers of elements) and ragged arrays are not supported in JAX.
If you actually have a static number of elements—meaning an equal number in every batch—then you can use the size
argument of jnp.where
to do this computation. It might look something like this:
from functools import partial
def jitt_function(line, N):
"""
Over-simplified function to showcase the problem
"""
inds = jnp.where(jnp.isfinite(line), size=N, fill_value=0)
return jnp.sum(line[inds])
pmap_func = jax.pmap(partial(jitt_function, N=4))
pmap_func(data_array)
If you have fewer than the specified number of entries in each batch, then one option is to specify the fill_value
argument to jnp.where
to pad the output. In this case, since you are taking the sum along each dimension, a fill value of zero returns the expected result.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论