JAX 0.2.17和JAX 0.4.1之间的内存需求巨大差异

huangapple go评论63阅读模式
英文:

Huge memory requirement difference between JAX 0.2.17 and JAX 0.4.1

问题

在接受答案中使用函数时,无论是否使用jax(bar或jit_bar):

T = np.random.rand(5000, 566, 3)
@jax.jit
def jit_bar(Y):
   u, v = jnp.triu_indices(Y.shape[0], 1)
   return jnp.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))
msd = jit_bar(T)

将一个 (10000x566x3) 的数组传递给该函数,在 python3.6 下会给我稳定的 1.5 GB 内存使用量,但在 python = 3.8, 3.9, 3.10, 3.11 下内存飙升到 +50 GB。

编辑:

经过一些尝试,似乎与仅与jax有关,此代码将在以下环境中正常运行:

python3.6, jax (0.2.17), jaxlib (0.1.68), numpy (1.19.2)

但在以下环境中不会正常运行:

python3.11, jax (0.4.1), jaxlib (0.4.1), numpy (1.24.1)

英文:

Follow up of the question:

https://stackoverflow.com/questions/74635970/is-it-possible-to-improve-python-performance-for-this-code/74636106?noredirect=1#comment131773811_74636106

When using the functions from the accepted answer, with or without jax (bar or jit_bar):

T = np.random.rand(5000, 566, 3)
@jax.jit
def jit_bar(Y):
   u, v = jnp.triu_indices(Y.shape[0], 1)
   return jnp.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))
msd = jit_bar(T)

Sending a (10000x566x3) array to the function give me a stable memory usage of 1.5 GB with python3.6, with python = 3.8, 3.9, 3.10, 3.11 the memory skyrocket to +50 GB.

EDIT:

After some trials it seems to be related to jax only, this code will run fine with:

python3.6, jax (0.2.17), jaxlib (0.1.68), numpy (1.19.2)

but not with:

python3.11, jax (0.4.1), jaxlib (0.4.1), numpy (1.24.1)

答案1

得分: 3

如果 Y 的形状是 (10000, 566, 3),那么 triu_indices 返回的数组长度是 (10000 * 10001) / 2,因此 Y[u]Y[v] 每个的大小都是 (50005000, 566, 3)。如果它们是 float32 值,那么每个的大小大约为 316 GB。我不会期望这段代码在任何地方都能运行良好!

我怀疑较旧的 JAX 版本可能具有一些后来删除的附加优化;根据你的计算形式,唯一可能的是对方差的分解,以避免实例化完整的矩阵和总和,我模糊记得以前是 XLA 的优化,但因为数值不稳定而被删除。

但如果你愿意,你可以手动进行这样的优化;以下是一个看似有效的方法,它为原始输入生成的最大中间数组的形状为 [10000, 10000],大约为 ~380MB(float32):

@jax.jit
def jit_bar2(Y):
   u, v = jnp.triu_indices(Y.shape[0], 1)
   Y = Y.reshape(Y.shape[0], -1)
   Y2m = (Y ** 2).mean(-1)
   YYTm = (Y @ Y.T) / Y.shape[1]
   return jnp.sqrt(3 * (Y2m[u] + Y2m[v] - 2 * YYTm[u, v]))

T = np.random.rand(50, 6, 3)  # 使用较小的输入进行测试
np.testing.assert_allclose(jit_bar(T), jit_bar2(T), atol=1E-5)

请注意,这只是代码的翻译部分,没有其他内容。

英文:

If Y is of shape (10000, 566, 3) Then triu_indices returns arrays of length (10000 * 10001) / 2, and so Y[u] and Y[v] are each of size (50005000, 566, 3). If they are float32 values, then that size is about 316 GB each. I would not expect this code to run well anywhere!

I suspect that older JAX versions may have had some additional optimization that was removed in later versions; given the form of your computation, the only thing that could have been is a factorization of the square difference to avoid instantiating the full matrix sum, which I vaguely recall was previously an XLA optimization but was removed because it's numerically unstable.

But you can do such an optimization manually if you wish; here's an approach that seems to work, and the largest intermediate array it generates for the original inputs is of shape [10000, 10000], about ~380MB in float32:

@jax.jit
def jit_bar2(Y):
   u, v = jnp.triu_indices(Y.shape[0], 1)
   Y = Y.reshape(Y.shape[0], -1)
   Y2m = (Y ** 2).mean(-1)
   YYTm = (Y @ Y.T) / Y.shape[1]
   return jnp.sqrt(3 * (Y2m[u] + Y2m[v] - 2 * YYTm[u, v]))

T = np.random.rand(50, 6, 3)  # test with a smaller input
np.testing.assert_allclose(jit_bar(T), jit_bar2(T), atol=1E-5)

huangapple
  • 本文由 发表于 2023年2月16日 18:59:39
  • 转载请务必保留本文链接:https://go.coder-hub.com/75471289.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定