英文:
How to vmap over cho_solve and cho_factor?
问题
以下是您要翻译的部分:
"因为以下代码的最后一行出现了以下错误:
> jax.errors.ConcretizationTypeError 遇到抽象的追踪值,但期望的是具体的值...
>
> 问题出现在 bool
函数中。
看起来是由于 cho_factor
返回的 lower
值引起的,_cho_solve
(注意下划线)需要这个值作为静态值。
我是 JAX 的新手,所以希望通过将 cho_factor
映射到 cho_solve
中来解决问题。我在这里做错了什么?
import jax
key = jax.random.PRNGKey(0)
k_y = jax.random.normal(key, (100, 10, 10))
y = jax.random.normal(key, (100, 10, 1))
matmul = jax.vmap(jax.numpy.matmul)
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor)
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve)
k_y = matmul(k_y, jax.numpy.transpose(k_y, (0, 2, 1)))
chol, lower = cho_factor(k_y)
result = cho_solve((chol, lower), y)
英文:
The following error appears because of the last line of code below:
> jax.errors.ConcretizationTypeError Abstract tracer value encountered where concrete value is expected...
>
> The problem arose with the bool
function.
It looks like it is due to the lower
return value from cho_factor
, which _cho_solve
(note underscore) requires as static.
I'm new to jax, so I was hoping that vmap-ing cho_factor
into cho_solve
would just work. What have I done wrong here?
import jax
key = jax.random.PRNGKey(0)
k_y = jax.random.normal(key, (100, 10, 10))
y = jax.random.normal(key, (100, 10, 1))
matmul = jax.vmap(jax.numpy.matmul)
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor)
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve)
k_y = matmul(k_y, jax.numpy.transpose(k_y, (0, 2, 1)))
chol, lower = cho_factor(k_y)
result = cho_solve((chol, lower), y)
答案1
得分: 1
问题是在每种情况下,lower
都是一个静态标量,不应该进行映射。因此,如果您指定 in_axes
和 out_axes
以便在轴 None
上映射 lower
,那么 vmap
应该工作:
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor, out_axes=(0, None))
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 0))
英文:
The issue is that in each case, lower
is a static scalar that should not be mapped over. So if you specify in_axes
and out_axes
so that lower
is mapped over axis None
, the vmap
should work:
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor, out_axes=(0, None))
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 0))
答案2
得分: 0
所以我没有成功地使cho_factor
和cho_solve
工作,但是通过使用cholesky
和solve_triangular
绕过了它:
cholesky = jax.vmap(jax.scipy.linalg.cholesky, in_axes=(0, None))
solve_tri = jax.vmap(jax.scipy.linalg.solve_triangular, in_axes=(0, 0, None, None))
L = cholesky(k_y, True)
result2 = solve_tri(L, solve_tri(L, y, 0, True), 1, True)
英文:
So I didn't manage to get cho_factor
and cho_solve
working, but worked around it using cholesky
and solve_triangular
:
cholesky = jax.vmap(jax.scipy.linalg.cholesky, in_axes=(0, None))
solve_tri = jax.vmap(jax.scipy.linalg.solve_triangular, in_axes=(0, 0, None, None))
L = cholesky(k_y, True)
result2 = solve_tri(L, solve_tri(L, y, 0, True), 1, True)
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论