英文:
using jax.vmap to vectorize along with broadcasting
问题
考虑以下玩具示例:
x = np.arange(3)
# np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
cfun = lambda x: np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
cfuns = jax.vmap(cfun)
# 对于一个2D的x:
x = np.arange(6).reshape(3,2)
cfuns(x)
其中 x-x[:,None]
是广播部分,并生成一个3x3的数组。
我希望cfuns能对x的每一行进行向量化。
JAX跟踪对象Traced<ShapedArray(int64[2,2])>的numpy.ndarray转换方法__array__()被调用,带有<BatchTrace(level=1/0)>,具有
val = Array([[[ 0, 1],
[-1, 0]],
[[ 0, 1],
英文:
Consider the following toy example:
x = np.arange(3)
# np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
cfun = lambda x: np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
cfuns = jax.vmap(cfun)
# for a 2d x:
x = np.arange(6).reshape(3,2)
cfuns(x)
where x-x[:,None]
is the broadcasting part and give a 3x3 array.
I want cfuns to be vectorized over each row of x.
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int64[2,2])>with<BatchTrace(level=1/0)> with
val = Array([[[ 0, 1],
[-1, 0]],
[[ 0, 1],
答案1
得分: 0
JAX的转换操作,如vmap
、jit
、grad
等,与标准的numpy
操作不兼容。相反,你应该使用jax.numpy
,它提供了一个类似的API,构建在与JAX兼容的操作之上:
import jax
import jax.numpy as jnp
x = jnp.arange(3)
cfun = lambda x: jnp.sum(jnp.sin(x - x[:, jnp.newaxis]), axis=1)
cfuns = jax.vmap(cfun)
# 对于一个2D的x:
x = jnp.arange(6).reshape(3, 2)
print(cfuns(x))
# [[ 0.84147096 -0.84147096]
# [ 0.84147096 -0.84147096]
# [ 0.84147096 -0.84147096]]
英文:
JAX transformations like vmap
, jit
, grad
, etc. are not compatible with standard numpy
operations. Instead you should use jax.numpy
, which provides a similar API built on JAX-compatible operations:
import jax
import jax.numpy as jnp
x = jnp.arange(3)
cfun = lambda x: jnp.sum(jnp.sin(x - x[:, jnp.newaxis]), axis=1)
cfuns = jax.vmap(cfun)
# for a 2d x:
x = jnp.arange(6).reshape(3,2)
print(cfuns(x))
# [[ 0.84147096 -0.84147096]
# [ 0.84147096 -0.84147096]
# [ 0.84147096 -0.84147096]]
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论