Jax如何使用函数的LAX后端实现。

huangapple go评论68阅读模式
英文:

How Jax use LAX-backend implementation of functions

问题

I need to compute the kron products of two arrays and I want to test if doing it using Jax is faster than doing it using Numpy.

Now, in numpy my code there is res = numpy.kron(x1,x2), in Jax there is jax.numpy.kron(x1,x2) but how can I use it properly?
My doubts are:

  • is it sufficient to replace numpy with jax.numpy as follows: res = jax.numpy.kron(x1,x2)?

  • should I first send x1 and x2 to the device using x1_dev = jax.device_put(x1) and after that run res = jax.numpy.kron(x1_dev,x2_dev)?

  • should I add jax.block_until_ready() to the jax.numpy.kron() call?

英文:

I need to compute the kron procuts of two arrays and I want to test if doing it using Jax is faster than doing it using Numpy.

Now, in numpy my code there is res = numpy.kron(x1,x2), in Jax there is jax.numpy.kron(x1,x2) but how can I use it properly?
My doubs are:

  • is it sufficient to replace numpy with jax.numpy as follows: res = jax.numpy.kron(x1,x2)?

  • should I first sent x1 and x2 to the device using x1_dev = jax.device_put(x1) and after that run res = jax.numpy.kron(x1_dev,x2_dev)?

  • should I add jax.block_until_ready() to the jax.numpy.kron() call?

答案1

得分: 1

这在 JAX 的 FAQ 中有详细解释,位于 JAX 代码基准测试 部分。特别是,如果你对 JAX 与 NumPy 的速度感兴趣,我建议阅读这一部分的第二节,JAX 是否比 NumPy 更快?,它提供了 JAX 何时比等效的 NumPy 代码更快或更慢的广泛概述。

至于对 kron 函数进行基准测试,根据那里的建议,我会像这样进行基准测试(使用 IPython 的 %timeit 以方便测试)。我在 Colab T4 GPU 运行时上运行了以下代码:

import numpy as np
x1 = np.random.rand(1000)
x2 = np.random.rand(1000)
%timeit np.kron(x1, x2)
6.52 ms ± 580 µs 每次循环(平均 ± 7 次运行,100 次循环的标准差)
import jax.numpy as jnp
x1_jax = jnp.array(x1)
x2_jax = jnp.array(x2)
%timeit jnp.kron(x1_jax, x2_jax).block_until_ready()
1.39 ms ± 2.28 ms 每次循环(平均 ± 7 次运行,1 次循环的标准差)

如果你想查看使用 jax.jit 进行的即时编译的效果,你可以像这样做:

import jax
jit_kron = jax.jit(jnp.kron)
_ = jit_kron(x1_jax, x2_jax)  # 在计时之前触发编译
%timeit jit_kron(x1_jax, x2_jax).block_until_ready()
116 µs ± 33.9 µs 每次循环(平均 ± 7 次运行,10000 次循环的标准差)
英文:

This is covered in JAX's FAQ under Benchmarking JAX Code.

In particular, if you're interested in the speed of JAX vs NumPy I would read the second section of this, Is JAX Faster Than Numpy? which gives a broad overview of when you should expect JAX to be faster or slower than equivalent NumPy code.

As for benchmarking kron: following the advice there, I would benchmark them like this (using IPython's %timeit for convenience). I ran the following on a Colab T4 GPU runtime:

import numpy as np
x1 = np.random.rand(1000)
x2 = np.random.rand(1000)
%timeit np.kron(x1, x2)
6.52 ms ± 580 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
import jax.numpy as jnp
x1_jax = jnp.array(x1)
x2_jax = jnp.array(x2)
%timeit jnp.kron(x1_jax, x2_jax).block_until_ready()
1.39 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

If you want to see the effect of just-in-time compilation with jax.jit, you can do something like this:

import jax
jit_kron = jax.jit(jnp.kron)
_ = jit_kron(x1_jax, x2_jax)  # trigger compilation before timing
%timeit jit_kron(x1_jax, x2_jax).block_until_ready()
116 µs ± 33.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

huangapple
  • 本文由 发表于 2023年5月24日 18:10:32
  • 转载请务必保留本文链接:https://go.coder-hub.com/76322383.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定