JAX VMAP Parallelization Details
Question
I was wondering how vmap's internals work.
When I vectorize code using jax.lax.map, I know that each element is processed sequentially. However, when I use vmap, the vectorized operation is apparently executed in parallel. Can someone give me a more detailed explanation of how this parallelization works? How does JAX determine the number of parallel processes, and can the user influence this behaviour?
Thanks in advance.
Answer 1
Score: 3
jax.vmap is a vectorizing/batching transform, not a parallelizing transform. Internally, it converts an unbatched function into a batched function, lowering to efficient batched primitive calls rather than an explicit map or loop.
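To make the "batched function" idea concrete, here is a minimal, self-contained sketch: vmap of an elementwise operation lowers to a single primitive call on the whole batch rather than a per-element loop (the exact jaxpr text may vary across JAX versions):
import jax
import jax.numpy as jnp

# vmap rewrites the traced computation using per-primitive batching rules,
# so the mapped sin becomes one sin applied to the full array.
xs = jnp.arange(4.0)
print(jax.make_jaxpr(jax.vmap(jnp.sin))(xs))
# { lambda ; a:f32[4]. let b:f32[4] = sin a in (b,) }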
For example, here is a simple function for which we create both a manually looped version and an automatically vectorized, batched version:
import jax
import numpy as np
import jax.numpy as jnp
def f(x, y):
  return x @ y
num_batches = 3
num_entries = 5
np.random.seed(0)
x = np.random.rand(num_batches, num_entries)
y = np.random.rand(num_batches, num_entries)
f_loop = lambda x, y: jnp.stack([f(xi, yi) for xi, yi in zip(x, y)])
f_vmap = jax.vmap(f)
print(f_loop(x, y))
# [1.3567398 2.1908383 1.6315514]
print(f_vmap(x, y))
# [1.3567398 2.1908383 1.6315514]
These return the same results (by design), but they are quite different under the hood; by printing the jaxpr, we see that the loop version requires a long sequence of XLA calls:
print(jax.make_jaxpr(f_loop)(x, y))
{ lambda ; a:f32[3,5] b:f32[3,5]. let
    c:f32[1,5] = slice[limit_indices=(1, 5) start_indices=(0, 0) strides=(1, 1)] a
    d:f32[5] = squeeze[dimensions=(0,)] c
    e:f32[1,5] = slice[limit_indices=(1, 5) start_indices=(0, 0) strides=(1, 1)] b
    f:f32[5] = squeeze[dimensions=(0,)] e
    g:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] d f
    h:f32[1,5] = slice[limit_indices=(2, 5) start_indices=(1, 0) strides=(1, 1)] a
    i:f32[5] = squeeze[dimensions=(0,)] h
    j:f32[1,5] = slice[limit_indices=(2, 5) start_indices=(1, 0) strides=(1, 1)] b
    k:f32[5] = squeeze[dimensions=(0,)] j
    l:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] i k
    m:f32[1,5] = slice[limit_indices=(3, 5) start_indices=(2, 0) strides=(1, 1)] a
    n:f32[5] = squeeze[dimensions=(0,)] m
    o:f32[1,5] = slice[limit_indices=(3, 5) start_indices=(2, 0) strides=(1, 1)] b
    p:f32[5] = squeeze[dimensions=(0,)] o
    q:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n p
    r:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
    s:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l
    t:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q
    u:f32[3] = concatenate[dimension=0] r s t
  in (u,) }
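As an aside, the length of this traced program grows linearly with the batch size. Here is a quick check, reusing f_loop and num_entries from above; x10 and y10 are throwaway arrays introduced purely for illustration, and the exact equation count may differ across JAX versions:
# Each batch element contributes its own slice/squeeze/dot_general equations.
x10 = np.random.rand(10, num_entries)
y10 = np.random.rand(10, num_entries)
print(len(jax.make_jaxpr(f_loop)(x10, y10).jaxpr.eqns))
# 61 -- roughly six equations per batch element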
On the other hand, the vmap version lowers to just a single generalized dot product:
print(jax.make_jaxpr(f_vmap)(x, y))
{ lambda ; a:f32[3,5] b:f32[3,5]. let
    c:f32[3] = dot_general[dimension_numbers=(([1], [1]), ([0], [0]))] a b
  in (c,) }
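For intuition, the dimension_numbers above say: contract axis 1 of both inputs and treat axis 0 of both as a batch axis, i.e. a per-row dot product. A quick sanity check, assuming the x and y defined above:
# An einsum spelling of the batched dot_general shown in the jaxpr.
print(jnp.einsum('bi,bi->b', x, y))
# matches f_vmap(x, y)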
Notice that there is nothing concerning parallelization here; rather, we've automatically created an efficient batched version of our operation. The same is true of vmap applied to more complicated functions: it does not involve parallelization; rather, it outputs an efficient batched version of your operation in an automated manner.
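For contrast with the jax.lax.map mentioned in the question: in current JAX versions, lax.map is implemented in terms of lax.scan, so it lowers to a single scan primitive whose body runs once per element, sequentially. A minimal sketch reusing f, x, and y from above (the printed jaxpr details may vary by version):
# lax.map carries the per-element computation inside a sequential scan.
f_lax_map = lambda x, y: jax.lax.map(lambda xy: f(*xy), (x, y))
print(jax.make_jaxpr(f_lax_map)(x, y))
# the jaxpr contains a single `scan` equation wrapping the dot_general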