如何在这个例子中使用和解释JAX的矢量-雅可比积(VJP)?

huangapple go评论60阅读模式
英文:

How to use and interpret JAX Vector-Jacobian Product (VJP) for this example?

问题

我在试图学习如何使用 JAX 找到矢量值 ODE 函数的 Jacobian。我使用了 https://implicit-layers-tutorial.org/implicit_functions/ 上的示例。该页面实现了自己的 ODE 积分器以及相关的自定义 forward-mode 和 reverse-mode Jacobian 函数。我正在尝试使用官方的 jax odeint 和 diffrax 库来复现它,但这两者主要使用反向模式 Vector Jacobian Product (VJP),而不是该页面上提供的正向模式 Jacobian Vector Product (JVP) 的示例代码。

这是我从该页面调整的代码片段:

# ...(你的代码)

# 现在尝试进行反向模式向量雅可比乘积(VJP),因为jax-odeint没有定义正向模式 JVP
vjp_ys, vjp_evolve = vjp(evolve, y0, rho, sigma, beta)

# vjp_ys 和 ys 相等,它们都是解的时间序列,包含 y 的 3 个分量(状态变量)
print(jnp.array_equal(ys, vjp_ys))

# 定义一些在 y0 和参数中的扰动
delta_y0 = jnp.array([0., 0., 0.])
delta_rho = 0.
delta_sigma = 0.
delta_beta = 1.

# 这里失败了
# vjp_evolve 是一个函数,但我不确定如何使用它来获得在 y0/参数变化时的扰动 delta_ys
vjp_evolve(delta_y0, delta_rho, delta_sigma, delta_beta)

最后一行引发错误:

TypeError: 'jax.vjp' 返回的应用于 evolve 的函数被调用了 4 个参数'jax.vjp' 返回的函数必须用一个参数调用该参数对应于 evolve 返回的单个值即使该返回值是元组或其他容器)。...

我怀疑我对反向模式 VJP 的概念以及在这种矢量值 ODE 情况下输入是什么感到困惑。如果我使用 diffrax 求解器,相同的问题可能会存在。

有趣的是,如果我在使用 diffrax 求解器时指定 adjoint=NoAdjoint,我可以在该网站上使用 diffrax 求解器重现正向模式 JVP 的结果:

# ...(你的代码)

# 我同样对如何在 diffrax 的默认反向模式 ODE 系统的自动微分中使用 VJP 感到困惑
# 但是,如果我指定 adjoint=NoAdjoint,我就可以在 diffrax 的 ODE 求解器中使用正向模式 JVP

如何在这个例子中使用和解释JAX的矢量-雅可比积(VJP)?

这重现了该网站的主要图表之一(显示 ODE 对 beta 参数的变化非常敏感)。我理解正向模式 JVP 的概念(在初始条件和/或参数的扰动下,JVP 给出随时间变化的 ODE 解的相应扰动)。但是反向模式 VJP 是什么,以及在上述 vjp_evolve 函数中正确的输入是什么?
1: https://i.stack.imgur.com/QXcDe.png

英文:

I am trying to learn how to find the Jacobian of a vector-valued ODE function using JAX. I am using the examples at https://implicit-layers-tutorial.org/implicit_functions/ That page implements its own ODE integrator and associated custom forward-mode and reverse-mode Jacobian functions. I am trying to reproduce that using the official jax odeint and diffrax libraries, but both of these primarily use reverse-mode Vector Jacobian Product (VJP) instead of the forward-mode Jacobian Vector Product (JVP) for which example code is available on that page.

Here is a code snippet that I adapted from that page:

import matplotlib.pyplot as plt

from jax.config import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import jit, jvp, vjp
from jax.experimental.ode import odeint

from diffrax import diffeqsolve, ODETerm, PIDController, SaveAt, Dopri5, NoAdjoint

# returns time derivatives of each of our 3 state variables (vector-valued function)
def f(state, t, args):
    x, y, z = state
    rho, sigma, beta = args 
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

# convenience function that calls jax-odeint given input initial conditions and parameters (this is the function that we want Jacobian/sensitivities of)
def evolve(y0, rho, sigma, beta): 
    return odeint(f, y0, tarr, (rho, sigma, beta))


# set up initial conditions, timespan for integration, and fiducial parameter values
y0 = jnp.array([5., 5., 5.])
tarr = jnp.linspace(0, 1., 1000)
rho = 28.
sigma = 10.
beta = 8/3. 


# first just make sure evolve() works 
ys = evolve(y0, rho, sigma, beta)

fig, ax = plt.subplots(1,figsize=(6,4),dpi=150,subplot_kw={'projection':'3d'})   
ax.plot(ys.T[0],ys.T[1],ys.T[2],'b-',lw=0.5)

# now try to take reverse-mode vector-jacobian product (VJP) since forward-mode JVP is not defined for jax-odeint
vjp_ys, vjp_evolve = vjp(evolve,y0,rho,sigma,beta)

# vjp_ys and ys are equal -- they are the solution time series of the 3 components (state variables) of y 
print(jnp.array_equal(ys,vjp_ys))

# define some perturbation in y0 and parameters 
delta_y0 = jnp.array([0., 0., 0.])
delta_rho = 0.
delta_sigma = 0.
delta_beta = 1.

####### THIS FAILS 
# vjp_evolve is a function but I am not sure how to use it to get perturbations delta_ys given y0/parameter variations
vjp_evolve(delta_y0,delta_rho,delta_sigma,delta_beta)

That last line raises an error:

TypeError: The function returned by `jax.vjp` applied to evolve was called with 4 arguments, but functions returned by `jax.vjp` must be called with a single argument corresponding to the single value returned by evolve (even if that returned value is a tuple or other container).

For example, if we have:

  def f(x):
    return (x, x)
  _, f_vjp = jax.vjp(f, 1.0)

the function `f` returns a single tuple as output, and so we call `f_vjp` with a single tuple as its argument:

  x_bar, = f_vjp((2.0, 2.0))

If we instead call `f_vjp(2.0, 2.0)`, with the values 'splatted out' as arguments rather than in a tuple, this error can arise.

I suspect I am confused at the concept of reverse-mode VJP and what the input would be in the case of this vector-valued ODE. The same problem would persist if I had used diffrax solvers.

For what it's worth, I can reproduce the forward-mode JVP results on that website if I use a diffrax solver while specifying adjoint=NoAdjoint, so that jax.jvp can be used:

# I am similarly confused about how to use VJP with diffrax's default reverse-mode autodiff of the ODE system
# however I am able to use forward-mode JVP with diffrax's ODE solver if I specify adjoint=NoAdjoint

# diffrax expects reverse order for inputs (time first, then state, then args) -- opposite of jax odeint 
def f_diffrax(t, state, args):
    x, y, z = state
    rho, sigma, beta = args 
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

# set up diffrax inputs as closely to jax-odeint as possible 
terms = ODETerm(f_diffrax)
t0 = 0.0
t1 = 1.0 
dt0 = None
max_steps = 16**3 # not sure if this is needed
tsave = SaveAt(ts=tarr,dense=True)

def evolve_diffrax(y0, rho, sigma, beta):
    return diffeqsolve(terms,Dopri5(),t0,t1,dt0,y0,jnp.array([rho,sigma,beta]),saveat=tsave,
                       stepsize_controller=PIDController(rtol=1.4e-8,atol=1.4e-8),max_steps=max_steps,adjoint=NoAdjoint())

# get solution AND differentials assuming the same changes in y0 and parameters as we tried (and failed) to get above 
diffrax_ys, diffrax_delta_ys = jvp(evolve_diffrax, (y0,rho,sigma,beta),(delta_y0,delta_rho,delta_sigma,delta_beta))

# get the actual solution arrays from the diffrax Solution objects 
diffrax_ys = diffrax_ys.ys
diffrax_delta_ys = diffrax_delta_ys.ys

# plot 
fig, ax = plt.subplots(1,figsize=(6,4),dpi=150,subplot_kw={'projection':'3d'})   
ax.plot(diffrax_ys.T[0],diffrax_ys.T[1],diffrax_ys.T[2],color='violet',lw=0.5)
ax.quiver(diffrax_ys.T[0][::10],diffrax_ys.T[1][::10],diffrax_ys.T[2][::10],
          diffrax_delta_ys.T[0][::10],diffrax_delta_ys.T[1][::10],diffrax_delta_ys.T[2][::10])
    

如何在这个例子中使用和解释JAX的矢量-雅可比积(VJP)?

That reproduces one of the main plots of that website (showing that the ODE is very sensitive to variations in the beta parameter). So I understand the concept of forward-mode JVP (given perturbations in initial conditions and/or parameters, JVP gives the corresponding perturbation in the ODE solution as a function of time). But what does reverse-mode VJP do and what would be the correct input to the vjp_evolve function above?

答案1

得分: 1

JVP 是正向模式自动微分:给定在原始点上函数输入的切线,它返回在输出上的切线。

VJP 是反向模式自动微分:给定在原始点上函数输出的余切线,它返回在输入上的余切线。

因此,您可以使用与 vjp_ys 相同形状的余切线调用 vjp_evolve

print(vjp_evolve(jnp.ones_like(vjp_ys)))
(Array([ 1.74762118, 26.45747015, -2.03017559], dtype=float64),
 Array(871.66349663, dtype=float64),
 Array(-83.07586548, dtype=float64),
 Array(-1754.48788565, dtype=float64))

从概念上讲,JVP 通过计算向前传播梯度,而 VJP 通过向后传播梯度。
JAX 文档可能对更深入理解 JVP 和 VJP 转换有帮助:https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff

英文:

JVP is forward-mode autodiff: given tangents of the input to the function at a primal point, it returns tangents on the outputs.

VJP is reverse-mode autodiff: given cotangents on the output of the function at a primal point, it returns cotangents on the inputs.

So you can call vjp_evolve with cotangents of the same shape as vjp_ys:

print(vjp_evolve(jnp.ones_like(vjp_ys)))
(Array([ 1.74762118, 26.45747015, -2.03017559], dtype=float64),
 Array(871.66349663, dtype=float64),
 Array(-83.07586548, dtype=float64),
 Array(-1754.48788565, dtype=float64))

Conceptually, JVP propagates gradients forward through a computation, while VJP propagates gradients backward.
The JAX docs might be useful background for understanding the JVP & VJP transformations more deeply: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff

huangapple
  • 本文由 发表于 2023年3月12日 13:59:59
  • 转载请务必保留本文链接:https://go.coder-hub.com/75711315.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定