如何在cho_solve和cho_factor上使用vmap?

huangapple go评论49阅读模式
英文:

How to vmap over cho_solve and cho_factor?

问题

以下是您要翻译的部分:

"因为以下代码的最后一行出现了以下错误:

> jax.errors.ConcretizationTypeError 遇到抽象的追踪值,但期望的是具体的值...
>
> 问题出现在 bool 函数中。

看起来是由于 cho_factor 返回的 lower 值引起的,_cho_solve(注意下划线)需要这个值作为静态值。

我是 JAX 的新手,所以希望通过将 cho_factor 映射到 cho_solve 中来解决问题。我在这里做错了什么?

import jax

key = jax.random.PRNGKey(0)
k_y = jax.random.normal(key, (100, 10, 10))
y = jax.random.normal(key, (100, 10, 1))

matmul = jax.vmap(jax.numpy.matmul)
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor)
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve)

k_y = matmul(k_y, jax.numpy.transpose(k_y, (0, 2, 1)))
chol, lower = cho_factor(k_y)
result = cho_solve((chol, lower), y)
英文:

The following error appears because of the last line of code below:

> jax.errors.ConcretizationTypeError Abstract tracer value encountered where concrete value is expected...
>
> The problem arose with the bool function.

It looks like it is due to the lower return value from cho_factor, which _cho_solve (note underscore) requires as static.

I'm new to jax, so I was hoping that vmap-ing cho_factor into cho_solve would just work. What have I done wrong here?

import jax

key = jax.random.PRNGKey(0)
k_y = jax.random.normal(key, (100, 10, 10))
y = jax.random.normal(key, (100, 10, 1))

matmul = jax.vmap(jax.numpy.matmul)
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor)
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve)

k_y = matmul(k_y, jax.numpy.transpose(k_y, (0, 2, 1)))
chol, lower = cho_factor(k_y)
result = cho_solve((chol, lower), y)

答案1

得分: 1

问题是在每种情况下,lower 都是一个静态标量,不应该进行映射。因此,如果您指定 in_axesout_axes 以便在轴 None 上映射 lower,那么 vmap 应该工作:

cho_factor = jax.vmap(jax.scipy.linalg.cho_factor, out_axes=(0, None))
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 0))
英文:

The issue is that in each case, lower is a static scalar that should not be mapped over. So if you specify in_axes and out_axes so that lower is mapped over axis None, the vmap should work:

cho_factor = jax.vmap(jax.scipy.linalg.cho_factor, out_axes=(0, None))
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 0))

答案2

得分: 0

所以我没有成功地使cho_factorcho_solve工作,但是通过使用choleskysolve_triangular绕过了它:

  cholesky = jax.vmap(jax.scipy.linalg.cholesky, in_axes=(0, None))
  solve_tri = jax.vmap(jax.scipy.linalg.solve_triangular, in_axes=(0, 0, None, None))

  L = cholesky(k_y, True)
  result2 = solve_tri(L, solve_tri(L, y, 0, True), 1, True)
英文:

So I didn't manage to get cho_factor and cho_solve working, but worked around it using cholesky and solve_triangular:

  cholesky = jax.vmap(jax.scipy.linalg.cholesky, in_axes=(0, None))
  solve_tri = jax.vmap(jax.scipy.linalg.solve_triangular, in_axes=(0, 0, None, None))

  L = cholesky(k_y, True)
  result2 = solve_tri(L, solve_tri(L, y, 0, True), 1, True)

huangapple
  • 本文由 发表于 2023年6月13日 00:33:29
  • 转载请务必保留本文链接:https://go.coder-hub.com/76458629.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定