英文:
How Jax use LAX-backend implementation of functions
问题
I need to compute the kron products of two arrays and I want to test if doing it using Jax is faster than doing it using Numpy.
Now, in numpy my code there is res = numpy.kron(x1,x2)
, in Jax there is jax.numpy.kron(x1,x2)
but how can I use it properly?
My doubts are:
-
is it sufficient to replace
numpy
withjax.numpy
as follows:res = jax.numpy.kron(x1,x2)
? -
should I first send x1 and x2 to the device using
x1_dev = jax.device_put(x1)
and after that runres = jax.numpy.kron(x1_dev,x2_dev)
? -
should I add
jax.block_until_ready()
to thejax.numpy.kron()
call?
英文:
I need to compute the kron procuts of two arrays and I want to test if doing it using Jax is faster than doing it using Numpy.
Now, in numpy my code there is res = numpy.kron(x1,x2)
, in Jax there is jax.numpy.kron(x1,x2)
but how can I use it properly?
My doubs are:
-
is it sufficient to replace
numpy
withjax.numpy
as follows:res = jax.numpy.kron(x1,x2)
? -
should I first sent x1 and x2 to the device using
x1_dev = jax.device_put(x1)
and after that runres = jax.numpy.kron(x1_dev,x2_dev)
? -
should I add
jax.block_until_ready()
to thejax.numpy.kron()
call?
答案1
得分: 1
这在 JAX 的 FAQ 中有详细解释,位于 JAX 代码基准测试 部分。特别是,如果你对 JAX 与 NumPy 的速度感兴趣,我建议阅读这一部分的第二节,JAX 是否比 NumPy 更快?,它提供了 JAX 何时比等效的 NumPy 代码更快或更慢的广泛概述。
至于对 kron
函数进行基准测试,根据那里的建议,我会像这样进行基准测试(使用 IPython 的 %timeit
以方便测试)。我在 Colab T4 GPU 运行时上运行了以下代码:
import numpy as np
x1 = np.random.rand(1000)
x2 = np.random.rand(1000)
%timeit np.kron(x1, x2)
6.52 ms ± 580 µs 每次循环(平均 ± 7 次运行,100 次循环的标准差)
import jax.numpy as jnp
x1_jax = jnp.array(x1)
x2_jax = jnp.array(x2)
%timeit jnp.kron(x1_jax, x2_jax).block_until_ready()
1.39 ms ± 2.28 ms 每次循环(平均 ± 7 次运行,1 次循环的标准差)
如果你想查看使用 jax.jit
进行的即时编译的效果,你可以像这样做:
import jax
jit_kron = jax.jit(jnp.kron)
_ = jit_kron(x1_jax, x2_jax) # 在计时之前触发编译
%timeit jit_kron(x1_jax, x2_jax).block_until_ready()
116 µs ± 33.9 µs 每次循环(平均 ± 7 次运行,10000 次循环的标准差)
英文:
This is covered in JAX's FAQ under Benchmarking JAX Code.
In particular, if you're interested in the speed of JAX vs NumPy I would read the second section of this, Is JAX Faster Than Numpy? which gives a broad overview of when you should expect JAX to be faster or slower than equivalent NumPy code.
As for benchmarking kron
: following the advice there, I would benchmark them like this (using IPython's %timeit
for convenience). I ran the following on a Colab T4 GPU runtime:
import numpy as np
x1 = np.random.rand(1000)
x2 = np.random.rand(1000)
%timeit np.kron(x1, x2)
6.52 ms ± 580 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
import jax.numpy as jnp
x1_jax = jnp.array(x1)
x2_jax = jnp.array(x2)
%timeit jnp.kron(x1_jax, x2_jax).block_until_ready()
1.39 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
If you want to see the effect of just-in-time compilation with jax.jit
, you can do something like this:
import jax
jit_kron = jax.jit(jnp.kron)
_ = jit_kron(x1_jax, x2_jax) # trigger compilation before timing
%timeit jit_kron(x1_jax, x2_jax).block_until_ready()
116 µs ± 33.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论