Counting nonzero elements with JAX and JIT

Question

I'm trying to write a simple Monte Carlo estimation of pi using JAX and JIT, but keep running into a TracerIntegerConversionError. How are JIT functions supposed to be designed to get around this? Thanks in advance!

import numpy as np
import jax
import jax.numpy as jnp
import time

n = 10000

@jax.jit
def jax_mc(n):
  x = jnp.array(np.random.rand(n))
  y = jnp.array(np.random.rand(n))
  z = x**2 + y**2
  a = jnp.where(z <= 1, z, 0)
  count = jnp.count_nonzero(a)
  return 4 * count / n

start_np = time.time()
print(jax_mc(n))
end_np = time.time()
print(end_np - start_np)
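Here is a minimal reproduction of the error. It assumes a recent JAX release, and the function name `sized_by_argument` is hypothetical. Under `jax.jit` the argument `n` arrives as a tracer, which is an abstract integer scalar, so NumPy cannot turn it into a concrete array size. JAX's tracer conversion errors, including the `TracerIntegerConversionError` reported above, subclass `TypeError`, so the sketch catches that base class.

```python
import numpy as np
import jax


@jax.jit
def sized_by_argument(n):
    # Under jit, n is a tracer (an abstract int32 scalar), not a Python int,
    # so NumPy cannot turn it into a concrete array length.
    return np.random.rand(n)


try:
    sized_by_argument(5)
    error_name = None
except TypeError as err:
    # JAX's tracer conversion errors (e.g. TracerIntegerConversionError)
    # are subclasses of TypeError.
    error_name = type(err).__name__

print(error_name)
```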

Answer 1

Score: 1


There are a few issues here.

First, array sizes must be static in JAX (for some discussion, see JAX Sharp Bits: Dynamic Shapes). In your function, n is a dynamic (traced) argument, so you cannot construct an array of shape n when the function is JIT-compiled.
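This small sketch illustrates the fix. It assumes a recent JAX release, and `make_ones` and `make_ones_traced` are illustrative names. Marking `n` as static with `static_argnames` makes it an ordinary Python int during tracing, so it can define a shape. The trade-off is that each distinct value of `n` triggers a separate trace and compilation. The counter makes this visible because Python side effects run only while JAX is tracing.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0


@partial(jax.jit, static_argnames=["n"])
def make_ones(n):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return jnp.ones(n)  # n is a concrete Python int here


@jax.jit
def make_ones_traced(n):
    return jnp.ones(n)  # n is a tracer here, so the shape is not static


make_ones(3)
make_ones(3)  # cache hit: same static value, no retrace
make_ones(5)  # new static value: traced and compiled again

try:
    make_ones_traced(3)
    traced_version_failed = False
except TypeError:
    traced_version_failed = True

print(trace_count, traced_version_failed)
```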

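Second, you're using numpy.random within a JIT-compiled function, which will most likely not behave as you expect: NumPy code runs only once, while the function is being traced, so its random values are frozen into the compiled function as constants. You should use jax.random instead; see JAX Sharp Bits: Random Numbers for more discussion of this.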
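This sketch shows the difference; the function names are hypothetical. NumPy's random call runs once at trace time, so every call of the compiled function returns the same "random" values. With `jax.random`, randomness comes from an explicit key argument. Different keys give different draws, and reusing a key reproduces a draw.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def numpy_noise(x):
    # np.random.rand executes at trace time only; the result becomes a constant
    return x + np.random.rand(3)


@jax.jit
def jax_noise(key, x):
    # The key is a traced input, so each call can produce new random numbers
    return x + jax.random.uniform(key, shape=x.shape)


zeros = jnp.zeros(3)
a = numpy_noise(zeros)
b = numpy_noise(zeros)  # identical to a: the "random" values were frozen

key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)
c = jax_noise(key1, zeros)
d = jax_noise(key2, zeros)  # different key, different values
```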

Third, your timing does not take into account JAX's asynchronous dispatch, and because it times the first call, it also includes compilation. Read Benchmarking JAX Code for some discussion of considerations when running these kinds of benchmarks.
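Here is a minimal benchmarking pattern along those lines. It uses a toy jitted function, `matmul_sum`, chosen only for illustration. Calling `block_until_ready()` on the result makes the timer wait for the computation to finish, not just for dispatch. Timing a warm-up call separately keeps compilation out of the steady-state measurement. `time.perf_counter` is used as a monotonic, high-resolution clock; `time.time` also works for coarse measurements.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def matmul_sum(x):
    return jnp.sum(x @ x)


x = jnp.ones((200, 200))

# First call: tracing + compilation + execution.
start = time.perf_counter()
matmul_sum(x).block_until_ready()
first_call = time.perf_counter() - start

# Later calls reuse the compiled executable; block so the timer covers
# the computation itself, not just the asynchronous dispatch.
start = time.perf_counter()
result = matmul_sum(x).block_until_ready()
steady_state = time.perf_counter() - start

print(f"first call: {first_call:.4f} s, steady state: {steady_state:.6f} s")
```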

With these points in mind, I would probably rewrite your code this way:

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import time

n = 10000

@partial(jax.jit, static_argnames=['n'])
def jax_mc(key, n):
  key1, key2 = jax.random.split(key)
  x = jax.random.normal(key1, shape=(n,))
  y = jax.random.normal(key2, shape=(n,))
  z = x**2 + y**2
  a = jnp.where(z <= 1, z, 0)
  count = jnp.count_nonzero(a)
  return 4 * count / n

key = jax.random.PRNGKey(0)
start = time.time()
_ = jax_mc(key, n)
end = time.time()
print('compilation time:', end - start)

start = time.time()
result = jax_mc(key, n).block_until_ready()
end = time.time()
print("runtime:", end - start)
print("result:", result)
compilation time: 0.8846290111541748
runtime: 0.0018401145935058594
result: 1.5516
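The rewrite draws `x` and `y` with `jax.random.normal`, which is why it prints about 1.55 rather than about 3.14. For independent standard normals, x² + y² follows a chi-squared distribution with 2 degrees of freedom. That gives P(x² + y² ≤ 1) = 1 − e^(−1/2) ≈ 0.3935, so `4 * count / n` converges to about 1.574. The quarter-circle estimate of π instead needs points uniform on the unit square, for example via `jax.random.uniform`. The sketch below compares both samplers; the name `mc_estimate` and its static `uniform` flag are illustrative, not part of the original answer. It also counts the boolean mask directly, which avoids the `jnp.where` step.

```python
import math
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["n", "uniform"])
def mc_estimate(key, n, uniform):
    key_x, key_y = jax.random.split(key)
    if uniform:
        # Points uniform on the unit square [0, 1) x [0, 1): estimates pi
        x = jax.random.uniform(key_x, shape=(n,))
        y = jax.random.uniform(key_y, shape=(n,))
    else:
        # Standard normal points, as in the rewrite above
        x = jax.random.normal(key_x, shape=(n,))
        y = jax.random.normal(key_y, shape=(n,))
    # Count the boolean mask directly; no jnp.where needed
    inside = jnp.count_nonzero(x**2 + y**2 <= 1)
    return 4 * inside / n


key = jax.random.PRNGKey(0)
pi_estimate = mc_estimate(key, 100_000, uniform=True)
normal_estimate = mc_estimate(key, 100_000, uniform=False)
normal_limit = 4 * (1 - math.exp(-0.5))  # about 1.574

print(float(pi_estimate), float(normal_estimate), normal_limit)
```

Because `uniform` is static, the Python `if` is resolved at trace time. Each setting compiles its own version of the function.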

huangapple
  • Posted on April 7, 2023, at 00:21:00
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75951674.html