使用 jax.vmap 进行向量化以及广播。

huangapple go评论57阅读模式
英文:

using jax.vmap to vectorize along with broadcasting

问题

考虑以下玩具示例:

x = np.arange(3)
# np.sum(np.sin(x - x[:, np.newaxis]), axis=1)

cfun = lambda x: np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
cfuns = jax.vmap(cfun)

# 对于一个2D的x:
x = np.arange(6).reshape(3,2)
cfuns(x)

其中 x-x[:,None] 是广播部分,并生成一个3x3的数组。
我希望cfuns能对x的每一行进行向量化。

JAX跟踪对象Traced<ShapedArray(int64[2,2])>的numpy.ndarray转换方法__array__()被调用,带有<BatchTrace(level=1/0)>,具有
  val = Array([[[ 0,  1],
        [-1,  0]],

       [[ 0,  1],
英文:

Consider the following toy example:

x = np.arange(3)
# np.sum(np.sin(x - x[:, np.newaxis]), axis=1)

cfun = lambda x: np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
cfuns = jax.vmap(cfun)

# for a 2d x:
x = np.arange(6).reshape(3,2)
cfuns(x)

where x-x[:,None] is the broadcasting part and give a 3x3 array.
I want cfuns to be vectorized over each row of x.

The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced&lt;ShapedArray(int64[2,2])&gt;with&lt;BatchTrace(level=1/0)&gt; with
  val = Array([[[ 0,  1],
        [-1,  0]],

       [[ 0,  1],

答案1

得分: 0

JAX的转换操作,如vmapjitgrad等,与标准的numpy操作不兼容。相反,你应该使用jax.numpy,它提供了一个类似的API,构建在与JAX兼容的操作之上:

import jax
import jax.numpy as jnp

x = jnp.arange(3)

cfun = lambda x: jnp.sum(jnp.sin(x - x[:, jnp.newaxis]), axis=1)
cfuns = jax.vmap(cfun)

# 对于一个2D的x:
x = jnp.arange(6).reshape(3, 2)

print(cfuns(x))
# [[ 0.84147096 -0.84147096]
#  [ 0.84147096 -0.84147096]
#  [ 0.84147096 -0.84147096]]
英文:

JAX transformations like vmap, jit, grad, etc. are not compatible with standard numpy operations. Instead you should use jax.numpy, which provides a similar API built on JAX-compatible operations:

import jax
import jax.numpy as jnp

x = jnp.arange(3)

cfun = lambda x: jnp.sum(jnp.sin(x - x[:, jnp.newaxis]), axis=1)
cfuns = jax.vmap(cfun)

# for a 2d x:
x = jnp.arange(6).reshape(3,2)

print(cfuns(x))
# [[ 0.84147096 -0.84147096]
#  [ 0.84147096 -0.84147096]
#  [ 0.84147096 -0.84147096]]

huangapple
  • 本文由 发表于 2023年2月8日 21:06:52
  • 转载请务必保留本文链接:https://go.coder-hub.com/75386255.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定