使用 jax.vmap 进行向量化以及广播。

huangapple go评论89阅读模式
英文:

using jax.vmap to vectorize along with broadcasting

问题

考虑以下玩具示例:

  1. x = np.arange(3)
  2. # np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
  3. cfun = lambda x: np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
  4. cfuns = jax.vmap(cfun)
  5. # 对于一个2D的x:
  6. x = np.arange(6).reshape(3,2)
  7. cfuns(x)

其中 x-x[:,None] 是广播部分,并生成一个3x3的数组。
我希望cfuns能对x的每一行进行向量化。

  1. JAX跟踪对象Traced<ShapedArray(int64[2,2])>的numpy.ndarray转换方法__array__()被调用,带有<BatchTrace(level=1/0)>,具有
  2. val = Array([[[ 0, 1],
  3. [-1, 0]],
  4. [[ 0, 1],
英文:

Consider the following toy example:

  1. x = np.arange(3)
  2. # np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
  3. cfun = lambda x: np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
  4. cfuns = jax.vmap(cfun)
  5. # for a 2d x:
  6. x = np.arange(6).reshape(3,2)
  7. cfuns(x)

where x-x[:,None] is the broadcasting part and give a 3x3 array.
I want cfuns to be vectorized over each row of x.

  1. The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced&lt;ShapedArray(int64[2,2])&gt;with&lt;BatchTrace(level=1/0)&gt; with
  2. val = Array([[[ 0, 1],
  3. [-1, 0]],
  4. [[ 0, 1],

答案1

得分: 0

JAX的转换操作,如vmapjitgrad等,与标准的numpy操作不兼容。相反,你应该使用jax.numpy,它提供了一个类似的API,构建在与JAX兼容的操作之上:

  1. import jax
  2. import jax.numpy as jnp
  3. x = jnp.arange(3)
  4. cfun = lambda x: jnp.sum(jnp.sin(x - x[:, jnp.newaxis]), axis=1)
  5. cfuns = jax.vmap(cfun)
  6. # 对于一个2D的x:
  7. x = jnp.arange(6).reshape(3, 2)
  8. print(cfuns(x))
  9. # [[ 0.84147096 -0.84147096]
  10. # [ 0.84147096 -0.84147096]
  11. # [ 0.84147096 -0.84147096]]
英文:

JAX transformations like vmap, jit, grad, etc. are not compatible with standard numpy operations. Instead you should use jax.numpy, which provides a similar API built on JAX-compatible operations:

  1. import jax
  2. import jax.numpy as jnp
  3. x = jnp.arange(3)
  4. cfun = lambda x: jnp.sum(jnp.sin(x - x[:, jnp.newaxis]), axis=1)
  5. cfuns = jax.vmap(cfun)
  6. # for a 2d x:
  7. x = jnp.arange(6).reshape(3,2)
  8. print(cfuns(x))
  9. # [[ 0.84147096 -0.84147096]
  10. # [ 0.84147096 -0.84147096]
  11. # [ 0.84147096 -0.84147096]]

huangapple
  • 本文由 发表于 2023年2月8日 21:06:52
  • 转载请务必保留本文链接:https://go.coder-hub.com/75386255.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定