你如何在Jax中实现动态范围上的可映射求和?

huangapple go评论54阅读模式
英文:

How can I implement a vmappable sum over a dynamic range in Jax?

问题

Sure, here's the translated code portion:

要在Jax中实现类似以下Python函数的内容并用`vmap`调用它我希望它在vmap之后仍然可以完全**反向模式**可微分关于`x`),使用`grad()`,即使kmax

```python
def f(x,kmax):
  return sum ([x**k for k in range(1,kmax+1)])

(这是函数的故意简化版本;我明白在这种情况下我可以使用几何级数的封闭形式表达式;不幸的是,我正在尝试实现的实际函数没有我知道的封闭形式总和。)

有没有办法做到这一点?看起来一定要有;但是,如果kmax是动态的,fori_loop不是反向模式可微分的,jax.lax.scan需要一个静态形状的数组,否则它会引发ConcretizationTypeError,类似的Python原语如range(如上所示)如果包装在vmap中,则会引发TracerIntegerConversionError

我想我了解了需要数组具有固定形状的限制,但是我使用过的每个自动微分框架都允许您以某种方式动态构建任意大小的表达式。对可变整数范围的总和是一种相当基本的数学工具。如何在Jax中实现这一点?

编辑以重新聚焦问题定义(问题主要是vmap而不是grad),并提供以下示例。

这是我想要能够执行的内容

import jax

def f(x,kmax):
  return sum ([x**k for k in range(1,kmax+1)])

fmap = jax.vmap(f,in_axes=(None,-1))

x = 3.
kmaxes = jax.numpy.array([1,2,3])

print(fmap(x,kmaxes))

fmap_sum = lambda k,kmaxes:jax.numpy.sum(fmap(k,kmaxes))

print(fmap_sum(x,kmaxes))
print(jax.grad(fmap_sum)(x,kmaxes))

这在range(1,kmax+1)处引发了TracerIntegerConversionError。我希望它能够执行类似于这样的操作:

import jax

def f(x,kmax):
  return sum ([x**k for k in range(1,kmax+1)])

def fmap(x,kmaxes):
  return [f(x,kmax) for kmax in kmaxes]

x = 3.
kmaxes = jax.numpy.array([1,2,3])

print(fmap(x,kmaxes))

def fmap_sum(x,kmaxes):
  return sum(fmap(x,kmaxes))

print(fmap_sum(x,kmaxes))
print(jax.grad(fmap_sum)(x,kmaxes))

这会给出正确的结果(但会失去vmap的并行化和加速功能)。


Hope this helps!

<details>
<summary>英文:</summary>

I want to implement something like the following Python function in Jax, and wrap it with a call to `vmap`. I want it to be fully **reverse-mode** differentiable (with respect to `x`) using `grad()`, even after the vmap.

def f(x,kmax):
return sum ([x**k for k in range(1,kmax+1)])


(This is a deliberately simplified version of the function; I realize in this case I could use the closed-form expression for the geometric series; sadly the actual function I&#39;m trying to implement does not have a closed-form sum that I&#39;m aware of.)

Is there any way to do this? It seems like there _has_ to be; but `fori_loop` is not reverse-mode differentiable if `kmax` is dynamic, `jax.lax.scan` needs a statically-shaped array or it will throw `ConcretizationTypeError`s, and similarly Python primitives like `range` (as used above) throw `TracerIntegerConversionError` if wrapped in `vmap`.

I think I understand the restrictions on needing arrays to be fixed-shape, but every autodiff framework I&#39;ve ever used allows you to construct arbitrarily-sized expressions dynamically *somehow*. A sum over a varying integer range is a pretty basic mathematical tool. How does one implement this in Jax?

EDITED to refocus the problem definition (the issue is more vmap than grad) and provide the following examples.

This, specifically, is what I&#39;d like to be able to do

import jax

def f(x,kmax):
return sum ([x**k for k in range(1,kmax+1)])

fmap = jax.vmap(f,in_axes=(None,-1))

x = 3.
kmaxes = jax.numpy.array([1,2,3])

print(fmap(x,kmaxes))

fmap_sum = lambda k,kmaxes:jax.numpy.sum(fmap(k,kmaxes))

print(fmap_sum(x,kmaxes))
print(jax.grad(fmap_sum)(x,kmaxes))


This throws a TracerIntegerConversionError at `range(1,kmax+1)`.
What I would like it to be doing is something like this:

import jax

def f(x,kmax):
return sum ([x**k for k in range(1,kmax+1)])

def fmap(x,kmaxes):
return [f(x,kmax) for kmax in kmaxes]

x = 3.
kmaxes = jax.numpy.array([1,2,3])

print(fmap(x,kmaxes))

def fmap_sum(x,kmaxes):
return sum(fmap(x,kmaxes))

print(fmap_sum(x,kmaxes))
print(jax.grad(fmap_sum)(x,kmaxes))


which gives the correct result (but loses the parallelization and acceleration of vmap).


</details>


# 答案1
**得分**: 1

Sure, here is the translation of the code you provided:

首先,为了使您的函数与 `vmap` 兼容,您需要用 [`jax.lax` 控制流](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow) 操作替换 Python 控制流。在这种情况下,`lax.fori_loop` 似乎是适用的:

```python
def f1(x, k):
  def body_fun(i, val):
    return val + x ** i
  return jax.lax.fori_loop(1, k + 1, body_fun, jnp.zeros_like(x))

f1map = jax.vmap(f1, (None, 0))
print(f1map(x, kmaxes))
# [ 3. 12. 39.]

但由于循环的大小是动态的,这不与反向模式自动微分兼容:

jax.jacrev(f1map)(x, kmaxes)
# ValueError: 反向模式微分不适用于 lax.while_loop 或 lax.fori_loop。尝试使用 lax.scan。

为了解决这个问题,您可以修改您的函数,使其使用静态循环大小。下面是一种可能的方式:

def f2(x, k, kmax):  # kmax 应该是静态的
  def body_fun(i, val):
    return val + jnp.where(i <= k, x ** i, 0)
  return jax.lax.fori_loop(1, kmax + 1, body_fun, jnp.zeros_like(x))

f2map = jax.vmap(f2, (None, 0, None))

print(f2map(x, kmaxes, kmaxes.max()))  # 与 vmap 兼容
# [ 3. 12. 39.]

print(jax.jacrev(f2map)(x, kmaxes, kmaxes.max()))  # 以及与反向模式自动微分兼容
# [ 1.  7. 34.]

Please note that the code includes links to external resources and specific Python functions, which are not translated.

英文:

First, to make your function compatible with vmap, you'll need to replace the Python control flow with jax.lax control flow operations. In this case, lax.fori_loop seems applicable:

def f1(x, k):
  def body_fun(i, val):
    return val + x ** i
  return jax.lax.fori_loop(1, k + 1, body_fun, jnp.zeros_like(x))

f1map = jax.vmap(f1, (None, 0))
print(f1map(x, kmaxes))
# [ 3. 12. 39.]

But because the size of the loop is dynamic, this is not compatible with reverse-mode autodiff:

jax.jacrev(f1map)(x, kmaxes)
# ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

To get around this, you can modify your function such that it uses a static loop size. Here's one way you might do that:

def f2(x, k, kmax):  # kmax should be static
  def body_fun(i, val):
    return val + jnp.where(i &lt;= k, x ** i, 0)
  return jax.lax.fori_loop(1, kmax + 1, body_fun, jnp.zeros_like(x))

f2map = jax.vmap(f2, (None, 0, None))

print(f2map(x, kmaxes, kmaxes.max()))  # compatible with vmap
# [ 3. 12. 39.]

print(jax.jacrev(f2map)(x, kmaxes, kmaxes.max()))  # and with reverse-mode autodiff
# [ 1.  7. 34.]

huangapple
  • 本文由 发表于 2023年5月25日 23:57:25
  • 转载请务必保留本文链接:https://go.coder-hub.com/76334231.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定