Huge memory requirement difference between JAX 0.2.17 and JAX 0.4.1
Question
Follow-up to an earlier question. When using the functions from the accepted answer, with or without jax (bar or jit_bar):
import jax
import jax.numpy as jnp
import numpy as np

T = np.random.rand(5000, 566, 3)

@jax.jit
def jit_bar(Y):
    u, v = jnp.triu_indices(Y.shape[0], 1)
    return jnp.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))

msd = jit_bar(T)
Sending a (10000x566x3) array to the function gives me a stable memory usage of 1.5 GB with Python 3.6; with Python 3.8, 3.9, 3.10, and 3.11 the memory skyrockets to 50+ GB.
EDIT:
After some trials it seems to be related to JAX only; this code runs fine with:
python3.6, jax (0.2.17), jaxlib (0.1.68), numpy (1.19.2)
but not with:
python3.11, jax (0.4.1), jaxlib (0.4.1), numpy (1.24.1)
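One way to observe the reported difference (an added sketch, not part of the original post; it assumes psutil is installed) is to watch the process's resident memory around the call:

# Editorial sketch, not from the original question.
import psutil

proc = psutil.Process()
T = np.random.rand(10000, 566, 3)
print(f"RSS before: {proc.memory_info().rss / 1e9:.1f} GB")
msd = jit_bar(T)            # jit_bar as defined above
msd.block_until_ready()     # wait for the asynchronous JAX computation to finish
print(f"RSS after:  {proc.memory_info().rss / 1e9:.1f} GB")

On the affected versions this run may be killed by the OS before the second print ever executes.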
Answer 1 (score: 3)
If Y has shape (10000, 566, 3), then triu_indices(10000, 1) returns arrays of length (10000 * 9999) / 2 = 49995000, and so Y[u] and Y[v] each have shape (49995000, 566, 3). If they are float32 values, that is about 316 GB each. I would not expect this code to run well anywhere!
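To make the arithmetic concrete, here is a quick back-of-the-envelope check (an added sketch, not from the original answer):

# Editorial sketch: size of the intermediates the original code materializes.
N, D = 10000, 566 * 3
n_pairs = N * (N - 1) // 2      # pairs above the diagonal from triu_indices(N, 1)
bytes_each = n_pairs * D * 4    # float32 is 4 bytes per element
print(n_pairs)                  # 49995000
print(bytes_each / 2**30)       # ~316, the "about 316 GB" figure above, per array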
I suspect that older JAX versions had some additional optimization that was removed in later releases; given the form of your computation, the only candidate is a factorization of the squared difference to avoid instantiating the full pairwise array, which I vaguely recall was previously an XLA optimization but was removed because it is numerically unstable. The relevant identity is mean((a - b) ** 2) = mean(a ** 2) + mean(b ** 2) - 2 * mean(a * b), which lets the pairwise cross terms be computed all at once as a matrix product.
But you can do such an optimization manually if you wish; here's an approach that seems to work, and the largest intermediate array it generates for the original inputs has shape [10000, 10000], about ~380 MB in float32:
@jax.jit
def jit_bar2(Y):
    u, v = jnp.triu_indices(Y.shape[0], 1)
    Y = Y.reshape(Y.shape[0], -1)
    # mean((Y[u] - Y[v]) ** 2) = Y2m[u] + Y2m[v] - 2 * YYTm[u, v]
    Y2m = (Y ** 2).mean(-1)        # per-row mean of squares, shape (N,)
    YYTm = (Y @ Y.T) / Y.shape[1]  # pairwise mean cross terms, shape (N, N)
    return jnp.sqrt(3 * (Y2m[u] + Y2m[v] - 2 * YYTm[u, v]))

T = np.random.rand(50, 6, 3)  # test with a smaller input
np.testing.assert_allclose(jit_bar(T), jit_bar2(T), atol=1E-5)
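As a rough sanity check on the memory claim (an added sketch, not from the original answer), the factored version can be run on the full-size input; the dominant allocations are the (10000, 10000) product plus a few pair-length vectors of roughly 200 MB each:

# Editorial sketch: full-size run; needs a few GB of RAM, not hundreds.
T_full = np.random.rand(10000, 566, 3)
msd = jit_bar2(T_full)
msd.block_until_ready()
print(msd.shape)  # (49995000,): one value per pair of rows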