英文:
How to use and interpret JAX Vector-Jacobian Product (VJP) for this example?
问题
我在试图学习如何使用 JAX 找到矢量值 ODE 函数的 Jacobian。我使用了 https://implicit-layers-tutorial.org/implicit_functions/ 上的示例。该页面实现了自己的 ODE 积分器以及相关的自定义 forward-mode 和 reverse-mode Jacobian 函数。我正在尝试使用官方的 jax odeint 和 diffrax 库来复现它,但这两者主要使用反向模式 Vector Jacobian Product (VJP),而不是该页面上提供的正向模式 Jacobian Vector Product (JVP) 的示例代码。
这是我从该页面调整的代码片段:
# ...(你的代码)
# 现在尝试进行反向模式向量雅可比乘积(VJP),因为jax-odeint没有定义正向模式 JVP
vjp_ys, vjp_evolve = vjp(evolve, y0, rho, sigma, beta)
# vjp_ys 和 ys 相等,它们都是解的时间序列,包含 y 的 3 个分量(状态变量)
print(jnp.array_equal(ys, vjp_ys))
# 定义一些在 y0 和参数中的扰动
delta_y0 = jnp.array([0., 0., 0.])
delta_rho = 0.
delta_sigma = 0.
delta_beta = 1.
# 这里失败了
# vjp_evolve 是一个函数,但我不确定如何使用它来获得在 y0/参数变化时的扰动 delta_ys
vjp_evolve(delta_y0, delta_rho, delta_sigma, delta_beta)
最后一行引发错误:
TypeError: 'jax.vjp' 返回的应用于 evolve 的函数被调用了 4 个参数,但 'jax.vjp' 返回的函数必须用一个参数调用,该参数对应于 evolve 返回的单个值(即使该返回值是元组或其他容器)。...
我怀疑我对反向模式 VJP 的概念以及在这种矢量值 ODE 情况下输入是什么感到困惑。如果我使用 diffrax 求解器,相同的问题可能会存在。
有趣的是,如果我在使用 diffrax 求解器时指定 adjoint=NoAdjoint,我可以在该网站上使用 diffrax 求解器重现正向模式 JVP 的结果:
# ...(你的代码)
# 我同样对如何在 diffrax 的默认反向模式 ODE 系统的自动微分中使用 VJP 感到困惑
# 但是,如果我指定 adjoint=NoAdjoint,我就可以在 diffrax 的 ODE 求解器中使用正向模式 JVP
这重现了该网站的主要图表之一(显示 ODE 对 beta 参数的变化非常敏感)。我理解正向模式 JVP 的概念(在初始条件和/或参数的扰动下,JVP 给出随时间变化的 ODE 解的相应扰动)。但是反向模式 VJP 是什么,以及在上述 vjp_evolve 函数中正确的输入是什么?
1: https://i.stack.imgur.com/QXcDe.png
英文:
I am trying to learn how to find the Jacobian of a vector-valued ODE function using JAX. I am using the examples at https://implicit-layers-tutorial.org/implicit_functions/ That page implements its own ODE integrator and associated custom forward-mode and reverse-mode Jacobian functions. I am trying to reproduce that using the official jax odeint and diffrax libraries, but both of these primarily use reverse-mode Vector Jacobian Product (VJP) instead of the forward-mode Jacobian Vector Product (JVP) for which example code is available on that page.
Here is a code snippet that I adapted from that page:
import matplotlib.pyplot as plt
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, jvp, vjp
from jax.experimental.ode import odeint
from diffrax import diffeqsolve, ODETerm, PIDController, SaveAt, Dopri5, NoAdjoint
# returns time derivatives of each of our 3 state variables (vector-valued function)
def f(state, t, args):
x, y, z = state
rho, sigma, beta = args
return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])
# convenience function that calls jax-odeint given input initial conditions and parameters (this is the function that we want Jacobian/sensitivities of)
def evolve(y0, rho, sigma, beta):
return odeint(f, y0, tarr, (rho, sigma, beta))
# set up initial conditions, timespan for integration, and fiducial parameter values
y0 = jnp.array([5., 5., 5.])
tarr = jnp.linspace(0, 1., 1000)
rho = 28.
sigma = 10.
beta = 8/3.
# first just make sure evolve() works
ys = evolve(y0, rho, sigma, beta)
fig, ax = plt.subplots(1,figsize=(6,4),dpi=150,subplot_kw={'projection':'3d'})
ax.plot(ys.T[0],ys.T[1],ys.T[2],'b-',lw=0.5)
# now try to take reverse-mode vector-jacobian product (VJP) since forward-mode JVP is not defined for jax-odeint
vjp_ys, vjp_evolve = vjp(evolve,y0,rho,sigma,beta)
# vjp_ys and ys are equal -- they are the solution time series of the 3 components (state variables) of y
print(jnp.array_equal(ys,vjp_ys))
# define some perturbation in y0 and parameters
delta_y0 = jnp.array([0., 0., 0.])
delta_rho = 0.
delta_sigma = 0.
delta_beta = 1.
####### THIS FAILS
# vjp_evolve is a function but I am not sure how to use it to get perturbations delta_ys given y0/parameter variations
vjp_evolve(delta_y0,delta_rho,delta_sigma,delta_beta)
That last line raises an error:
TypeError: The function returned by `jax.vjp` applied to evolve was called with 4 arguments, but functions returned by `jax.vjp` must be called with a single argument corresponding to the single value returned by evolve (even if that returned value is a tuple or other container).
For example, if we have:
def f(x):
return (x, x)
_, f_vjp = jax.vjp(f, 1.0)
the function `f` returns a single tuple as output, and so we call `f_vjp` with a single tuple as its argument:
x_bar, = f_vjp((2.0, 2.0))
If we instead call `f_vjp(2.0, 2.0)`, with the values 'splatted out' as arguments rather than in a tuple, this error can arise.
I suspect I am confused at the concept of reverse-mode VJP and what the input would be in the case of this vector-valued ODE. The same problem would persist if I had used diffrax solvers.
For what it's worth, I can reproduce the forward-mode JVP results on that website if I use a diffrax solver while specifying adjoint=NoAdjoint, so that jax.jvp can be used:
# I am similarly confused about how to use VJP with diffrax's default reverse-mode autodiff of the ODE system
# however I am able to use forward-mode JVP with diffrax's ODE solver if I specify adjoint=NoAdjoint
# diffrax expects reverse order for inputs (time first, then state, then args) -- opposite of jax odeint
def f_diffrax(t, state, args):
x, y, z = state
rho, sigma, beta = args
return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])
# set up diffrax inputs as closely to jax-odeint as possible
terms = ODETerm(f_diffrax)
t0 = 0.0
t1 = 1.0
dt0 = None
max_steps = 16**3 # not sure if this is needed
tsave = SaveAt(ts=tarr,dense=True)
def evolve_diffrax(y0, rho, sigma, beta):
return diffeqsolve(terms,Dopri5(),t0,t1,dt0,y0,jnp.array([rho,sigma,beta]),saveat=tsave,
stepsize_controller=PIDController(rtol=1.4e-8,atol=1.4e-8),max_steps=max_steps,adjoint=NoAdjoint())
# get solution AND differentials assuming the same changes in y0 and parameters as we tried (and failed) to get above
diffrax_ys, diffrax_delta_ys = jvp(evolve_diffrax, (y0,rho,sigma,beta),(delta_y0,delta_rho,delta_sigma,delta_beta))
# get the actual solution arrays from the diffrax Solution objects
diffrax_ys = diffrax_ys.ys
diffrax_delta_ys = diffrax_delta_ys.ys
# plot
fig, ax = plt.subplots(1,figsize=(6,4),dpi=150,subplot_kw={'projection':'3d'})
ax.plot(diffrax_ys.T[0],diffrax_ys.T[1],diffrax_ys.T[2],color='violet',lw=0.5)
ax.quiver(diffrax_ys.T[0][::10],diffrax_ys.T[1][::10],diffrax_ys.T[2][::10],
diffrax_delta_ys.T[0][::10],diffrax_delta_ys.T[1][::10],diffrax_delta_ys.T[2][::10])
That reproduces one of the main plots of that website (showing that the ODE is very sensitive to variations in the beta parameter). So I understand the concept of forward-mode JVP (given perturbations in initial conditions and/or parameters, JVP gives the corresponding perturbation in the ODE solution as a function of time). But what does reverse-mode VJP do and what would be the correct input to the vjp_evolve function above?
答案1
得分: 1
JVP 是正向模式自动微分:给定在原始点上函数输入的切线,它返回在输出上的切线。
VJP 是反向模式自动微分:给定在原始点上函数输出的余切线,它返回在输入上的余切线。
因此,您可以使用与 vjp_ys
相同形状的余切线调用 vjp_evolve
:
print(vjp_evolve(jnp.ones_like(vjp_ys)))
(Array([ 1.74762118, 26.45747015, -2.03017559], dtype=float64),
Array(871.66349663, dtype=float64),
Array(-83.07586548, dtype=float64),
Array(-1754.48788565, dtype=float64))
从概念上讲,JVP 通过计算向前传播梯度,而 VJP 通过向后传播梯度。
JAX 文档可能对更深入理解 JVP 和 VJP 转换有帮助:https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff
英文:
JVP is forward-mode autodiff: given tangents of the input to the function at a primal point, it returns tangents on the outputs.
VJP is reverse-mode autodiff: given cotangents on the output of the function at a primal point, it returns cotangents on the inputs.
So you can call vjp_evolve
with cotangents of the same shape as vjp_ys
:
print(vjp_evolve(jnp.ones_like(vjp_ys)))
(Array([ 1.74762118, 26.45747015, -2.03017559], dtype=float64),
Array(871.66349663, dtype=float64),
Array(-83.07586548, dtype=float64),
Array(-1754.48788565, dtype=float64))
Conceptually, JVP propagates gradients forward through a computation, while VJP propagates gradients backward.
The JAX docs might be useful background for understanding the JVP & VJP transformations more deeply: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论