JAX code for minimizing Lennard-Jones potential for 2 points in Python gives unexpected results

Question


I am trying to practice using JAX for optimization problems, and I started with a simple one: minimizing the Lennard-Jones potential for just 2 points. I set both epsilon and sigma in the Lennard-Jones potential to 1, so the potential is just F = 4(1/r^12 - 1/r^6), where r is the distance between the two points. The result should be r = 2^(1/6), which is approximately 1.12.
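For reference, the expected minimum comes from setting the derivative of the potential to zero (a standard calculation spelled out here, not part of the original question):

$$\frac{dF}{dr} = 4\epsilon\left(-\frac{12\sigma^{12}}{r^{13}} + \frac{6\sigma^{6}}{r^{7}}\right) = 0 \;\Rightarrow\; r^{6} = 2\sigma^{6} \;\Rightarrow\; r = 2^{1/6}\,\sigma \approx 1.1225$$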

Using JAX, I wrote the following code, which is pretty simple and short. My initial guess for the positions of the two points is [0, 1], which I think is reasonable (for the Lennard-Jones potential, a guess where r is too small could be a problem, because the potential approaches infinity). As I mentioned, I am expecting a value of r around 1.12 after the minimization; however, the result I get is [-0.71276042 1.71276042], so the distance is about 2.4, which is clearly too big, and I am wondering how I can fix it. I originally suspected a precision issue, so I changed the data type to float64, but the results are still the same. Any help will be greatly appreciated! Here is my code:

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax import vmap
import matplotlib.pyplot as plt

N = 2
jax.config.update("jax_enable_x64", True)
x_init = jnp.arange(N, dtype=jnp.float64)
epsilon = 1
sigma = 1

def potential(r):
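    # Guard against r == 0 (the self-interaction on the diagonal) to avoid division by zero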
    r = jnp.where(r == 0, jnp.finfo(jnp.float64).eps, r)
    return 4 * epsilon * ((sigma/r)**12 - (sigma/r)**6)

def F(x):
    # Compute all pairwise distances
    r = jnp.abs(x[:, None] - x[None, :])
    # Compute all pairwise potentials
    pot = vmap(vmap(potential))(r)
    # Exclude the diagonal (distance = 0) and avoid double-counting by taking upper triangular part
    pot = jnp.triu(pot, 1)
    # Sum up all the potentials
    total = jnp.sum(pot)
    return total

# Minimize the total potential energy with BFGS
result = minimize(F, x_init, method='BFGS')

# Extract the optimized positions of the points
x_solutions = result.x
print(x_solutions)

Answer 1

Score: 0


This function is one that would be very difficult for any unconstrained gradient-based optimizer to correctly optimize. Holding one point at zero and varying the other point on the range (0, 5], we see the potential looks like this:

r = jnp.linspace(0.1, 5.0, 1000)
plt.plot(r, jax.vmap(lambda ri: F(jnp.array([0, ri])))(r))
plt.ylim(-2, 10)

[Plot: the two-point energy F as a function of separation r, diverging steeply as r → 0, with a single narrow minimum near r ≈ 1.12 and a nearly flat tail for larger r]

To the left of the minimum, the gradient quickly diverges to negative infinity, meaning for nearly any reasonable step size, the optimizer will likely overshoot the minimum. Then on the right side, if the optimizer goes even a few units too far, the gradient tends to zero, meaning for nearly any reasonable step size, the optimizer will get stuck in a regime where the potential has almost no variation.
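As a quick numerical illustration of both regimes (a sketch of my own, reusing the F and imports defined in the question):

g = jax.grad(F)
# Just left of the minimum (r = 0.9): gradient components of magnitude ~1e2,
# so almost any reasonable step overshoots the minimum.
print(g(jnp.array([0.0, 0.9])))
# Well to the right (r = 3.0): gradient components of magnitude ~1e-2,
# so the optimizer barely moves.
print(g(jnp.array([0.0, 3.0])))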

Add to this the fact that you've set up the model with two degrees of freedom in a degenerate potential (F depends only on the separation |x1 - x0|, so any uniform translation of both points leaves it unchanged), and it's not surprising that gradient-based optimization methods are failing.

You can make some progress here by minimizing the log of the shifted potential, which has the effect of smoothing the steep gradients, and lets the BFGS minimizer find an expected minimum:

result = minimize(lambda x: jnp.log(2 + F(x)), x_init, method='BFGS')
print(result.x)
# [-0.06123102  1.06123102]
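The separation between the two optimized points is 1.06123102 - (-0.06123102) ≈ 1.1225, which matches the expected 2^(1/6).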

But in general my suggestion would probably be to opt for a constrained optimization approach instead, perhaps one of the JAXOpt constrained optimization methods, where you can rule out problematic regions of the parameter space.
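As a minimal sketch of that idea (my addition, not part of the original answer), assuming the jaxopt package is installed, its SciPy L-BFGS-B wrapper can box-constrain the positions; the particular bounds below are illustrative assumptions chosen to keep r away from both the r → 0 wall and the flat tail:

import jaxopt

# Illustrative box bounds (assumptions): x[0] in [-1, 0], x[1] in [0.8, 2],
# which keeps the separation r within [0.8, 3].
lower = jnp.array([-1.0, 0.8])
upper = jnp.array([0.0, 2.0])

solver = jaxopt.ScipyBoundedMinimize(fun=F, method="l-bfgs-b")
res = solver.run(x_init, bounds=(lower, upper))
print(res.params)  # expected: two points separated by roughly 2**(1/6)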
