Counting nonzero elements with JAX and JIT
Question
I'm trying to write a simple Monte Carlo estimation of pi using JAX and JIT, but keep running into a TracerIntegerConversionError. How are JIT functions supposed to be designed to get around this? Thanks in advance!
import numpy as np
import jax
import jax.numpy as jnp
import time

n = 10000

@jax.jit
def jax_mc(n):
    x = jnp.array(np.random.rand(n))
    y = jnp.array(np.random.rand(n))
    z = x**2 + y**2
    a = jnp.where(z <= 1, z, 0)
    count = jnp.count_nonzero(a)
    return 4 * count / n

start_np = time.time()
print(jax_mc(n))
end_np = time.time()
print(end_np - start_np)
Answer 1
Score: 1
There are a few issues here.
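First, array sizes must be static in JAX (for some discussion, see JAX Sharp Bits: Dynamic Shapes). In your function, n is a dynamic (traced) argument, so you cannot construct an array of shape n when the function is JIT-compiled. That is where the TracerIntegerConversionError comes from: np.random.rand(n) tries to convert the tracer standing in for n into a Python integer. Marking n as static, as in the rewrite, makes it an ordinary Python int during tracing.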
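As a minimal sketch of what static_argnames buys you and what it costs: a static argument is a plain Python value at trace time, so it can be used as a shape, but JAX traces and compiles a separate version for every distinct value. The function ones_sum and the trace_count counter below are hypothetical, for illustration only; the counter works because Python side effects inside a jitted function run only while JAX is tracing it.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical counter: the Python body of a jitted function only runs while
# JAX traces it, so this counts traces (and therefore compilations).
trace_count = 0

@partial(jax.jit, static_argnames=['n'])
def ones_sum(n):
    global trace_count
    trace_count += 1  # side effect executes only during tracing
    # n is a concrete Python int here, so it is a valid array shape
    return jnp.ones((n,)).sum()

print(ones_sum(3))  # traces and compiles for n=3
print(ones_sum(3))  # cache hit: no new trace
print(ones_sum(5))  # new static value: traces and compiles again
print("traces:", trace_count)  # 2
```

If n takes many different values, each one pays the compilation cost once, which is the trade-off of making it static.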
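Second, you're using numpy.random within a JIT-compiled function, which will most likely not behave as you expect: NumPy code runs only once, while JAX traces the function, so its samples are baked into the compiled program as constants and every later call reuses them. You should use jax.random instead; see JAX Sharp Bits: Random Numbers for more discussion of this.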
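A small sketch of the difference, assuming n is made static so that tracing succeeds (numpy_noise and jax_noise are hypothetical names). The NumPy call executes once during tracing, so repeated calls return the same array; jax.random takes an explicit key, and you get new draws by splitting keys.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['n'])
def numpy_noise(n):
    # np.random.rand runs once, at trace time; its output is embedded in
    # the compiled program as a constant array
    return jnp.asarray(np.random.rand(n))

@partial(jax.jit, static_argnames=['n'])
def jax_noise(key, n):
    # jax.random is a pure function of the key, traced as a regular op
    return jax.random.uniform(key, shape=(n,))

a = numpy_noise(3)
b = numpy_noise(3)
print(bool(jnp.all(a == b)))  # True: the "random" values never change

key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)
print(bool(jnp.all(jax_noise(key1, 3) == jax_noise(key1, 3))))  # True: same key, same draws
print(bool(jnp.any(jax_noise(key1, 3) != jax_noise(key2, 3))))  # True: new key, new draws
```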
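Third, your timing does not take JAX's asynchronous dispatch into account: a call can return as soon as the work has been queued, before it has finished, and the first call also includes compilation. Read Benchmarking JAX Code for some discussion of considerations when running these kinds of benchmarks.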
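A minimal benchmarking sketch along those lines, using a hypothetical toy function work: time the first call separately because it includes compilation, call block_until_ready() so the clock stops only after the computation has finished, and average several calls. time.perf_counter is used here instead of time.time because it is a monotonic, high-resolution clock.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def work(x):
    # hypothetical toy workload: elementwise math followed by a reduction
    return jnp.sum(jnp.sin(x) ** 2)

# make sure the input itself is materialized before timing starts
x = jnp.arange(1_000_000, dtype=jnp.float32).block_until_ready()

# First call: includes tracing and compilation. block_until_ready() waits
# until the asynchronously dispatched computation has actually finished.
t0 = time.perf_counter()
work(x).block_until_ready()
first_call = time.perf_counter() - t0

# Steady state: average over several calls to the already-compiled function.
runs = 10
t0 = time.perf_counter()
for _ in range(runs):
    out = work(x).block_until_ready()
steady = (time.perf_counter() - t0) / runs

print(f"first call: {first_call:.4f} s, later calls: {steady:.6f} s each")
```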
With these points in mind, I would probably rewrite your code this way:
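from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import time

n = 10000

# n is static: it is a plain Python int at trace time, so it can be used as a shape
@partial(jax.jit, static_argnames=['n'])
def jax_mc(key, n):
    key1, key2 = jax.random.split(key)
    # NOTE: jax.random.normal draws from N(0, 1); a pi estimate needs
    # uniform samples on [0, 1) (jax.random.uniform)
    x = jax.random.normal(key1, shape=(n,))
    y = jax.random.normal(key2, shape=(n,))
    z = x**2 + y**2
    a = jnp.where(z <= 1, z, 0)
    count = jnp.count_nonzero(a)
    return 4 * count / n

key = jax.random.PRNGKey(0)

# first call: includes tracing and compilation
start = time.time()
_ = jax_mc(key, n)
end = time.time()
print('compilation time:', end - start)

# second call: reuses the compiled function; block_until_ready waits for the result
start = time.time()
result = jax_mc(key, n).block_until_ready()
end = time.time()
print("runtime:", end - start)
print("result:", result)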
compilation time: 0.8846290111541748
runtime: 0.0018401145935058594
result: 1.5516
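One further thing stands out in the output: 1.5516 is not close to pi. That is because the rewrite draws x and y with jax.random.normal, i.e. from a standard normal, while the question sampled uniformly on [0, 1) with np.random.rand. For two independent standard normals, x² + y² is exponentially distributed with mean 2, so 4·P(x² + y² ≤ 1) = 4(1 − e^(−1/2)) ≈ 1.57, which matches the printed result. A sketch of the same function with uniform sampling (the name jax_mc_uniform is hypothetical) could look like this; it also counts the boolean mask directly rather than going through jnp.where:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['n'])
def jax_mc_uniform(key, n):
    key1, key2 = jax.random.split(key)
    # uniform samples on [0, 1), matching np.random.rand in the question
    x = jax.random.uniform(key1, shape=(n,))
    y = jax.random.uniform(key2, shape=(n,))
    # number of points inside the quarter circle of radius 1
    inside = jnp.count_nonzero(x**2 + y**2 <= 1)
    # the fraction inside approximates pi / 4
    return 4 * inside / n

key = jax.random.PRNGKey(0)
estimate = jax_mc_uniform(key, 100_000).block_until_ready()
print("pi estimate:", estimate)
```

With 100,000 samples the standard error of this estimate is roughly 0.005, so it should land close to 3.14.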