JAX VMAP 并行化细节

huangapple go评论72阅读模式
英文:

JAX VMAP Parallelization Details

问题

我想知道 vmap 内部是如何工作的。
当我使用 jax.lax.map 来矢量化代码时,我知道每个元素是依次执行的。然而,当我使用 vmap 时,我似乎是并行执行矢量化操作的。能否有人为我提供更详细的解释,并解释并行化的工作原理?Jax 是如何确定并行进程的数量的,用户能否影响此行为?

提前致谢。

英文:

I was wondering how vmap's internals work.
When I vectorize code using jax.lax.map, I know that each element is executed consecutively. However when I use vmap I execute the vectorized operation apparently in parallel. Can someone provide me with a more detailed explanation of how the parallelization works? How does Jax determine the number of parallel processes, and can this behaviour be influenced by the user?

Thanks in advance.

答案1

得分: 3

jax.vmap是一种向量化/批处理转换,而不是一种并行转换。在内部,它将未批处理的函数转换为批处理函数,降低了对高效原始调用的映射或循环。

举个例子,这里有一个简单的函数,我们创建了一个手动循环和一个自动向量化批处理版本:

这些返回相同的结果(设计如此),但在底层却有很大区别;通过打印 jaxpr,我们可以看到循环版本需要一长串的 XLA 调用:

另一方面,vmap 版本只降低到一个广义点积:

请注意这里与并行化无关;相反,我们自动创建了一个高效的批处理版本。对于应用于更复杂函数的 vmap 也是如此:它不涉及并行化,而是以自动化方式输出您操作的高效批处理版本。

英文:

jax.vmap is a vectorizing/batching transform, not a parallelizing transform. Internally, it converts an unbatched function to a batched function, lowering to efficient primitive calls rather than an explicit map or loop.

For example, here is a simple function, where we create both a manually-looped and an automatically vectorized batched version:

import jax
import numpy as np
import jax.numpy as jnp

def f(x, y):
  return x @ y

num_batches = 3
num_entries = 5

np.random.seed(0)
x = np.random.rand(num_batches, num_entries)
y = np.random.rand(num_batches, num_entries)


f_loop = lambda x, y: jnp.stack([f(xi, yi) for xi, yi in zip(x, y)])
f_vmap = jax.vmap(f)

print(f_loop(x, y))
# [1.3567398 2.1908383 1.6315514]
print(f_vmap(x, y))
# [1.3567398 2.1908383 1.6315514]

These return the same results (by design), but they are quite different under the hood; by printing the jaxpr, we see that the loop version requires a long sequence of XLA calls:

print(jax.make_jaxpr(f_loop)(x, y))
{ lambda ; a:f32[3,5] b:f32[3,5]. let
    c:f32[1,5] = slice[limit_indices=(1, 5) start_indices=(0, 0) strides=(1, 1)] a
    d:f32[5] = squeeze[dimensions=(0,)] c
    e:f32[1,5] = slice[limit_indices=(1, 5) start_indices=(0, 0) strides=(1, 1)] b
    f:f32[5] = squeeze[dimensions=(0,)] e
    g:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] d f
    h:f32[1,5] = slice[limit_indices=(2, 5) start_indices=(1, 0) strides=(1, 1)] a
    i:f32[5] = squeeze[dimensions=(0,)] h
    j:f32[1,5] = slice[limit_indices=(2, 5) start_indices=(1, 0) strides=(1, 1)] b
    k:f32[5] = squeeze[dimensions=(0,)] j
    l:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] i k
    m:f32[1,5] = slice[limit_indices=(3, 5) start_indices=(2, 0) strides=(1, 1)] a
    n:f32[5] = squeeze[dimensions=(0,)] m
    o:f32[1,5] = slice[limit_indices=(3, 5) start_indices=(2, 0) strides=(1, 1)] b
    p:f32[5] = squeeze[dimensions=(0,)] o
    q:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n p
    r:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
    s:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l
    t:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q
    u:f32[3] = concatenate[dimension=0] r s t
  in (u,) }

On the other hand, the vmap version lowers to just a single generalized dot product:

print(jax.make_jaxpr(f_vmap)(x, y))
{ lambda ; a:f32[3,5] b:f32[3,5]. let
    c:f32[3] = dot_general[dimension_numbers=(([1], [1]), ([0], [0]))] a b
  in (c,) }

Notice that there is nothing concerning parallelization here; rather we've automatically created an efficient batched version of our operation. The same is true of vmap applied to more complicated functions: it does not involve parallelization, rather it outputs an efficient batched version of your operation in an automated manner.

huangapple
  • 本文由 发表于 2023年3月31日 02:35:59
  • 转载请务必保留本文链接:https://go.coder-hub.com/75891826.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定