Is there a way to accept a function while taking the gradient using jax.grad?

Question

I am trying to build a neural-network-based solver for the differential equation y' + 2xy = 0 with the initial condition y(0) = 1.
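The equation is separable, so the exact solution the network should approach can be derived by hand. Assuming the initial condition y(0) = 1 (the value passed as `initial_condition` in the training call), a short derivation:

```latex
\frac{dy}{dx} = -2xy
\quad\Longrightarrow\quad \int \frac{dy}{y} = -\int 2x \, dx
\quad\Longrightarrow\quad \ln\lvert y \rvert = -x^{2} + C
\quad\Longrightarrow\quad y(x) = y(0)\, e^{-x^{2}} = e^{-x^{2}}.
```

This is the curve the "exact" line in a plot should show.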

import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

def softplus(x):
    return jnp.log(1 + jnp.exp(x))

def init_params():
    params = jax.random.normal(key, shape=(241,))
    return params

def linear_model(params, x):
    w0 = params[:80]
    b0 = params[80:160]
    w1 = params[160:240]
    b1 = params[240]
    h = softplus(x*w0 + b0)
    o = jnp.sum(h*w1) + b1
    return o

def loss(derivative, initial_condition, params, model, x):
    dfdx = jax.grad(model, 1)
    dfdx_vect = jax.vmap(dfdx, (None, 0))
    model_vect = jax.vmap(model, (None, 0))
    eq_difference = dfdx_vect(params, x) - derivative(x, model_vect(params, x))
    condition_difference = model(params, 0) - initial_condition
    return jnp.mean(eq_difference ** 2 + condition_difference ** 2)

def dfdx(x, y):
    return -2. * x * y

key = jax.random.PRNGKey(0)
inputs = np.linspace(0, 1, num=401)
params = init_params()

epochs = 2000
learning_rate = 0.0005

# Training Neural Network

for epoch in tqdm(range(epochs)):
    grad_loss = jax.grad(loss)
    gradient = grad_loss(dfdx, 1., params, linear_model, inputs)
    params -= learning_rate*gradient

model_vect = jax.vmap(linear_model, (None, 0))
preds = model_vect(params, inputs)

plt.plot(inputs, jnp.exp(-inputs**2), label='exact')
plt.plot(inputs, preds, label='approx')
plt.legend()
plt.show()

The issue is that JAX refuses to take the gradient of a function that receives another function as an argument; the call fails with:

TypeError: Argument '<function dfdx at 0x7fce88340af0>' of type <class 'function'> is not a valid JAX type.
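The error does not depend on the neural network at all. The minimal sketch below (the names `apply_and_square` and `double` are hypothetical) reproduces it and suggests that what matters is which argument `jax.grad` differentiates, not the mere fact that a function is passed in. It assumes a reasonably recent JAX version.

```python
import jax

def apply_and_square(fn, x):
    # fn is an ordinary Python callable; x is the number we want d/dx of
    return fn(x) ** 2

def double(x):
    return 2.0 * x

# jax.grad defaults to argnums=0, so it tries to differentiate w.r.t. `fn`.
# A Python function is not a valid JAX array type, so this raises TypeError.
try:
    jax.grad(apply_and_square)(double, 3.0)
    raised = False
except TypeError:
    raised = True

# Passing the callable works as long as it is not the differentiated argument:
# d/dx (2x)^2 = 8x, which is 24 at x = 3.
grad_x = jax.grad(apply_and_square, argnums=1)(double, 3.0)
print(raised, float(grad_x))
```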

Is there any workaround for this?

Answer 1

Score: 1

You just ordered the arguments wrong. By default, jax.grad differentiates with respect to the first argument (argnums=0), and you don't want to differentiate with respect to your function but with respect to the parameters. Just make params the first argument:

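def loss(params, derivative, initial_condition, model, x):
    dfdx = jax.grad(model, 1)                 # d(model)/dx for a single scalar x
    dfdx_vect = jax.vmap(dfdx, (None, 0))
    model_vect = jax.vmap(model, (None, 0))   # evaluate the model at every x
    eq_difference = dfdx_vect(params, x) - derivative(x, model_vect(params, x))
    condition_difference = model(params, 0.) - initial_condition
    # add (not subtract) the initial-condition penalty so both residuals shrink
    return jnp.mean(eq_difference ** 2 + condition_difference ** 2)

# the call site has to follow the new argument order
for epoch in tqdm(range(epochs)):
    gradient = jax.grad(loss)(params, dfdx, 1., linear_model, inputs)
    params -= learning_rate*gradient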
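A self-contained sketch of the whole solver with `params` first follows. Assumptions beyond the answer: `jax.nn.softplus` stands in for the hand-written softplus (same function, numerically stable); the right-hand side is renamed to the hypothetical `rhs` so it no longer shares a name with the derivative inside `loss`; `loss_fn` and `grad_fn` are hypothetical names; and `jax.jit` with `static_argnums` is added so the function-valued arguments are treated as compile-time constants. The learning rate and epoch count are taken from the question; how closely the result matches exp(-x²) is not guaranteed.

```python
import jax
import jax.numpy as jnp

def linear_model(params, x):
    # same 1 -> 80 -> 1 softplus network as in the question
    w0, b0, w1, b1 = params[:80], params[80:160], params[160:240], params[240]
    h = jax.nn.softplus(x * w0 + b0)
    return jnp.sum(h * w1) + b1

def rhs(x, y):
    # f(x, y) in y' = f(x, y); for y' + 2xy = 0 this is -2xy
    return -2.0 * x * y

def loss(params, derivative, initial_condition, model, x):
    # params comes first, so jax.grad's default argnums=0 targets it
    dydx = jax.vmap(jax.grad(model, 1), (None, 0))(params, x)
    y = jax.vmap(model, (None, 0))(params, x)
    eq_difference = dydx - derivative(x, y)
    condition_difference = model(params, 0.0) - initial_condition
    return jnp.mean(eq_difference ** 2) + condition_difference ** 2

# The callables (positions 1 and 3) are hashable, so jit can treat them as
# static arguments instead of trying to trace them as arrays.
loss_fn = jax.jit(loss, static_argnums=(1, 3))
grad_fn = jax.jit(jax.grad(loss), static_argnums=(1, 3))

params = jax.random.normal(jax.random.PRNGKey(0), shape=(241,))
inputs = jnp.linspace(0.0, 1.0, num=401)
learning_rate = 0.0005

initial_loss = float(loss_fn(params, rhs, 1.0, linear_model, inputs))
for epoch in range(2000):
    params = params - learning_rate * grad_fn(params, rhs, 1.0, linear_model, inputs)
final_loss = float(loss_fn(params, rhs, 1.0, linear_model, inputs))

# exact solution for y(0) = 1 is exp(-x^2); compare it with the network
approx = jax.vmap(linear_model, (None, 0))(params, inputs)
max_error = float(jnp.max(jnp.abs(approx - jnp.exp(-inputs ** 2))))
print(f"loss {initial_loss:.4f} -> {final_loss:.4f}, max |error| = {max_error:.4f}")
```

Because jit specializes on static arguments, passing a different `rhs` or model triggers a recompilation rather than an error. If you would rather keep the question's original signature, `jax.grad(loss, argnums=2)` differentiates with respect to `params` in third position instead.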

huangapple
  • Posted on April 17, 2023 02:52:40
  • Please keep the link to this article when reposting: https://go.coder-hub.com/76029743.html