Defining the correct vectorization axes for JAX vmap with arrays of different shapes and sizes

Question

Following the answer to this post, the function 'f_switch' below dynamically switches between multiple functions based on an index array (it is built on 'jax.lax.switch'):

import jax
import jax.numpy as jnp
from jax import vmap
import jax.random as random

def g_0(x, y, z, u): return x + y + z + u
def g_1(x, y, z, u): return x * y * z * u
def g_2(x, y, z, u): return x - y + z - u
def g_3(x, y, z, u): return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]


@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i: jax.lax.switch(i, g_i, x, y, z, u)
  return jax.vmap(g)(i)

With input arrays i_ar of shape (len_i,), x_ar, y_ar and z_ar of shape (len_xyz,), and u_ar of shape (len_u, len_xyz), out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar) has shape (len_i, len_u, len_xyz), because each branch broadcasts the (len_xyz,) arrays against u_ar and the vmap over i_ar adds the leading axis:

len_i = 50
i_ar = random.randint(random.PRNGKey(5), shape=(len_i,), minval=0, maxval=len(g_i))

len_xyz = 3000
x_ar = random.uniform(random.PRNGKey(0), shape=(len_xyz,))
y_ar = random.uniform(random.PRNGKey(1), shape=(len_xyz,))
z_ar = random.uniform(random.PRNGKey(2), shape=(len_xyz,))

len_u = 1000
u_0 = random.uniform(random.PRNGKey(3), shape=(len_u,))
u_1 = jnp.repeat(u_0, len_xyz)
u_ar = u_1.reshape(len_u, len_xyz)

out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
print('The shape of out is', out.shape)

This works. However, how can f_switch be defined so that out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar) has shape (j_len, k_len, l_len) when the function is applied along the following axes: i_ar[j], x_ar[j], y_ar[j, k], z_ar[j, k], u_ar[l]? I am not sure how to do this. Examples of these input arrays:

j_len = 82
k_len = 20
l_len = 100
i_ar = random.randint(random.PRNGKey(0), shape=(j_len,), minval=0, maxval=len(g_i))
x_ar = random.uniform(random.PRNGKey(1), shape=(j_len,))
y_ar = random.uniform(random.PRNGKey(2), shape=(j_len, k_len))
z_ar = random.uniform(random.PRNGKey(3), shape=(j_len, k_len))
u_ar = random.uniform(random.PRNGKey(4), shape=(l_len,))

I tried to resolve this (i.e. to get an output of shape (j_len, k_len, l_len) from the given input arrays) with a nested vmap:

@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i, x, y, z, u: jax.lax.switch(i, g_i, x, y, z, u)
  g_map = jax.vmap(g, in_axes=(None, 0, 0, 0, 0))
  wrapper = lambda x, y, z, u: g_map(i, x, y, z, u)
  return jax.vmap(wrapper, in_axes=(0, None, None, None, 0))(x, y, z, u)

I also tried broadcasting u_ar with u_ar_broadcast = jnp.broadcast_to(u_ar, (j_len, k_len, l_len)) and then passing it to the original f_switch. Both attempts failed.

Answer 1

Score: 0

It looks like maybe you want something like this?

@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i, x, y, z, u: jax.lax.switch(i, g_i, x, y, z, u)
  g = jax.vmap(g, (None, None, None, None, 0))
  g = jax.vmap(g, (None, None, 0, 0, None))
  g = jax.vmap(g, (0, 0, 0, 0, None))
  return g(i, x, y, z, u)

out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
print(out.shape)
# (82, 20, 100)

You should read the in_axes from bottom to top (because the bottom vmap is the outer one, and is therefore applied to the inputs first). Schematically, you can think of the effect of the maps on the shapes as something like this:

                               (i[82], x[82], y[82,20], z[82,20], u[100])
(0, 0, 0, 0, None)          -> (i,     x,     y[20],    z[20],    u[100])
(None, None, 0, 0, None)    -> (i,     x,     y,        z,        u[100])
(None, None, None, None, 0) -> (i,     x,     y,        z,        u)
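
One way to convince yourself that the axis mapping in the schematic is the intended one is to compare the nested-vmap result with a plain Python triple loop over j, k and l. The sketch below does that on deliberately small sizes. The sizes, the minval/maxval range (chosen to keep the division branch away from values near zero) and the names ref and matches are assumptions made for illustration, not part of the original question.

```python
import jax
import jax.random as random
import numpy as np

def g_0(x, y, z, u): return x + y + z + u
def g_1(x, y, z, u): return x * y * z * u
def g_2(x, y, z, u): return x - y + z - u
def g_3(x, y, z, u): return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]

@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i, x, y, z, u: jax.lax.switch(i, g_i, x, y, z, u)
  g = jax.vmap(g, (None, None, None, None, 0))  # innermost: maps over l (u only)
  g = jax.vmap(g, (None, None, 0, 0, None))     # middle: maps over k (y and z)
  g = jax.vmap(g, (0, 0, 0, 0, None))           # outermost: maps over j (i, x, y, z)
  return g(i, x, y, z, u)

# Small, hypothetical sizes so the reference loop stays cheap.
j_len, k_len, l_len = 4, 3, 5
i_ar = random.randint(random.PRNGKey(0), shape=(j_len,), minval=0, maxval=len(g_i))
x_ar = random.uniform(random.PRNGKey(1), shape=(j_len,), minval=0.5, maxval=1.5)
y_ar = random.uniform(random.PRNGKey(2), shape=(j_len, k_len), minval=0.5, maxval=1.5)
z_ar = random.uniform(random.PRNGKey(3), shape=(j_len, k_len), minval=0.5, maxval=1.5)
u_ar = random.uniform(random.PRNGKey(4), shape=(l_len,), minval=0.5, maxval=1.5)

out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)

# Reference: out[j, k, l] = g_{i[j]}(x[j], y[j, k], z[j, k], u[l])
i_np, x_np, y_np, z_np, u_np = (np.asarray(a) for a in (i_ar, x_ar, y_ar, z_ar, u_ar))
ref = np.empty((j_len, k_len, l_len), dtype=np.float32)
for j in range(j_len):
  for k in range(k_len):
    for l in range(l_len):
      ref[j, k, l] = g_i[int(i_np[j])](x_np[j], y_np[j, k], z_np[j, k], u_np[l])

matches = bool(np.allclose(np.asarray(out), ref, rtol=1e-5))
print(out.shape, matches)
```

The loop spells out the contract directly, so if an in_axes entry were placed on the wrong level, either the shape or the comparison would be expected to change.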

That said, often it is easier to rely on numpy-style broadcasting rather than on multiple nested vmaps. For example, you could also do something like this:

@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i, x, y, z, u: jax.lax.switch(i, g_i, x, y, z, u)
  return jax.vmap(g, in_axes=(0, 0, 0, 0, None))(i, x, y, z, u)

out = f_switch(i_ar, x_ar[:, None, None], y_ar[:, :, None], z_ar[:, :, None], u_ar)
print(out.shape)
# (82, 20, 100)
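
As a rough sanity check, the two variants can be run side by side on the arrays from the question and compared. In the sketch below, the names f_switch_nested, f_switch_broadcast and same are hypothetical and only serve to keep the two definitions apart.

```python
import jax
import jax.numpy as jnp
import jax.random as random

def g_0(x, y, z, u): return x + y + z + u
def g_1(x, y, z, u): return x * y * z * u
def g_2(x, y, z, u): return x - y + z - u
def g_3(x, y, z, u): return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]

def g(i, x, y, z, u):
  return jax.lax.switch(i, g_i, x, y, z, u)

@jax.jit
def f_switch_nested(i, x, y, z, u):
  # Three vmaps, one per output axis (l innermost, j outermost).
  h = jax.vmap(g, (None, None, None, None, 0))
  h = jax.vmap(h, (None, None, 0, 0, None))
  h = jax.vmap(h, (0, 0, 0, 0, None))
  return h(i, x, y, z, u)

@jax.jit
def f_switch_broadcast(i, x, y, z, u):
  # A single vmap over j; the k and l axes come from broadcasting
  # x[:, None, None], y[:, :, None], z[:, :, None] against u.
  h = jax.vmap(g, in_axes=(0, 0, 0, 0, None))
  return h(i, x[:, None, None], y[:, :, None], z[:, :, None], u)

j_len, k_len, l_len = 82, 20, 100
i_ar = random.randint(random.PRNGKey(0), shape=(j_len,), minval=0, maxval=len(g_i))
x_ar = random.uniform(random.PRNGKey(1), shape=(j_len,))
y_ar = random.uniform(random.PRNGKey(2), shape=(j_len, k_len))
z_ar = random.uniform(random.PRNGKey(3), shape=(j_len, k_len))
u_ar = random.uniform(random.PRNGKey(4), shape=(l_len,))

out_nested = f_switch_nested(i_ar, x_ar, y_ar, z_ar, u_ar)
out_broadcast = f_switch_broadcast(i_ar, x_ar, y_ar, z_ar, u_ar)

same = bool(jnp.allclose(out_nested, out_broadcast))
print(out_nested.shape, out_broadcast.shape, same)
```

Here the reshaping is moved inside f_switch_broadcast so that both functions take the same raw inputs; passing the already-reshaped arrays from the call site, as in the answer, is equivalent. Which form to prefer is largely a matter of readability.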

huangapple
  • Posted on February 16, 2023, 05:13:40
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75465486.html