Is there a way to accept a function while taking the gradient using jax.grad?

I am trying to make a neural network-based differential equation solver for the differential equation y' + 2xy = 0.

    import jax.numpy as jnp
    import jax
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    import numpy as np

    def softplus(x):
        return jnp.log(1 + jnp.exp(x))

    def init_params():
        params = jax.random.normal(key, shape=(241,))
        return params

    def linear_model(params, x):
        w0 = params[:80]
        b0 = params[80:160]
        w1 = params[160:240]
        b1 = params[240]
        h = softplus(x*w0 + b0)
        o = jnp.sum(h*w1) + b1
        return o

    def loss(derivative, initial_condition, params, model, x):
        dfdx = jax.grad(model, 1)
        dfdx_vect = jax.vmap(dfdx, (None, 0))
        model_vect = jax.vmap(model, (None, 0))
        eq_difference = dfdx_vect(params, x) - derivative(x, model(params, x))
        condition_difference = model(params, 0) - initial_condition
        return jnp.mean(eq_difference ** 2 - condition_difference ** 2)

    def dfdx(x, y):
        return -2. * x * y

    key = jax.random.PRNGKey(0)
    inputs = np.linspace(0, 1, num=401)
    params = init_params()
    epochs = 2000
    learning_rate = 0.0005

    # Training Neural Network
    for epoch in tqdm(range(epochs)):
        grad_loss = jax.grad(loss)
        gradient = grad_loss(dfdx, 1., params, linear_model, inputs)
        params -= learning_rate*gradient

    model_vect = jax.vmap(linear_model, (None, 0))
    preds = model_vect(params, inputs)

    plt.plot(inputs, jnp.exp(inputs**2), label='exact')
    plt.plot(inputs, model_vect(params, inputs), label='approx')
    plt.legend()
    plt.show()

The issue is that JAX doesn't like taking the gradient of a function that receives another function as an argument:

    TypeError: Argument '<function dfdx at 0x7fce88340af0>' of type <class 'function'> is not a valid JAX type.

Is there any workaround for this?

Answer 1

Score: 1

You just ordered the arguments wrong. By default, JAX differentiates with respect to the first argument, and you don't want to differentiate with respect to your function, but rather the parameters. Just make them the first argument.

    def loss(params, derivative, initial_condition, model, x):
        dfdx = jax.grad(model, 1)
        dfdx_vect = jax.vmap(dfdx, (None, 0))
        model_vect = jax.vmap(model, (None, 0))
        eq_difference = dfdx_vect(params, x) - derivative(x, model(params, x))
        condition_difference = model(params, 0) - initial_condition
        return jnp.mean(eq_difference ** 2 - condition_difference ** 2)
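
With this signature the training loop has to pass params first as well. A minimal sketch of the updated training step, assuming the rest of the original script stays unchanged:

    # loss now takes params as its first argument, so jax.grad
    # differentiates with respect to the parameter vector.
    grad_loss = jax.grad(loss)
    for epoch in tqdm(range(epochs)):
        gradient = grad_loss(params, dfdx, 1., linear_model, inputs)
        params -= learning_rate * gradient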

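Alternatively, if you'd rather keep the question's original signature loss(derivative, initial_condition, params, model, x), jax.grad takes an argnums parameter selecting which positional argument to differentiate with respect to; only that argument has to be a valid JAX type, so the function-valued arguments are simply passed through. A short sketch of that variant:

    # Differentiate the original loss with respect to its third argument (params).
    grad_loss = jax.grad(loss, argnums=2)
    gradient = grad_loss(dfdx, 1., params, linear_model, inputs)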