从sympy.diff中提取函数的方法是什么?

huangapple go评论74阅读模式
英文:

How to extract the function from sympy.diff?

问题

I'm here to assist with your translation request. Here is the translated code portion:

我正在尝试使用`sympy.diff`来计算导数函数然后解决此函数中的`'x'`,使导数函数`func(x) = 0`的根然而解决这个导数函数非常慢因为它返回了五个解但我只需要最接近固定值`x0`的解

```python
import sympy

def diff_dist_func(a, b, c):
    x = sympy.Symbol('x')
    x0 = sympy.Symbol('x0')
    y0 = sympy.Symbol('y0')

    dist = ((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2) ** (1 / 2)
    return sympy.diff(dist, x)

a = -0.00020129919480721813
b = 0.10107634020780536
c = -12.305150031126267
shortest_dist = diff_dist_func(a, b, c)

x0 = 252.3007982720215
y0 = 96.55526056735049
solve_shortest_dist = shortest_dist.evalf(subs={'x0': x0, 'y0': y0})  # 构建导数函数
solve_x = sympy.solve(solve_shortest_dist, sympy.Symbol('x'), simplify=False, rational=False)  # 这里解决导数函数非常慢。

为了加速解决方案,我尝试使用scipy.optimize.fsolve,它能够为func(x) = 0的根提供一个起始估计值x0。因此,我用fsolve(solve_shortest_dist, np.array([x0]))替换了sympy.solve,但出现了错误TypeError: 'Mul' object is not callable。如何从sympy.diff的输出中提取导数函数,使其能够被scipy.optimize.fsolve解决?或者有没有任何方法来加速解决过程?

英文:

I'm trying to calculate a derivative function using sympy.diff, and then solve the 'x' in this function letting the roots of derivative function func(x) = 0. However, solve this derivative function is very slow, because it returns five solutions, but I only need the solution most close to a fixed value x0.

import sympy

def diff_dist_func(a, b, c):
    x = sympy.Symbol('x')
    x0 = sympy.Symbol('x0')
    y0 = sympy.Symbol('y0')

    dist = ((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2) ** (1 / 2)
    return sympy.diff(dist, x)

a = -0.00020129919480721813
b=0.10107634020780536
c=-12.305150031126267
shortest_dist = diff_dist_func(a, b, c)
    
x0=252.3007982720215
y0=96.55526056735049
solve_shortest_dist = shortest_dist.evalf(subs={'x0': x0, 'y0': y0})  # build the derivative function
solve_x = sympy.solve(solve_shortest_dist, sympy.Symbol('x'), simplify=False, rational=False)  # Here solve the derivative function is very slow.

To speed up the solution, I try to use scipy.optimize.fsolve, which is able to give a starting estimate x0 for the roots of func(x) = 0. So I replace sympy.solve with fsolve(solve_shortest_dist, np.array([x0])), but an error occured TypeError: 'Mul' object is not callable. How to extract the derivative function from the output of sympy.diff, enable it to be solved by scipy.optimize.fsolve? Or are there any way to speed up the solving process?

答案1

得分: 2

我使用SymPy 1.12,`solve`非常快速暂时不需要转向数值库

def diff_dist_func(a, b, c):
    x = sympy.Symbol('x')
    x0 = sympy.Symbol('x0')
    y0 = sympy.Symbol('y0')

    dist = sympy.sqrt((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2)
    return sympy.diff(dist, x)

a = -0.00020129919480721813
b = 0.10107634020780536
c = -12.305150031126267
shortest_dist = diff_dist_func(a, b, c)

x0 = 252.3007982720215
y0 = 96.55526056735049
sols = sympy.solve(shortest_dist.subs({'x0': x0, 'y0': y0}), x)
sols
# 输出: [-5.99737132198727,
#  76.9370574126698,
#  252.303674956197,
#  256.81160524887 - 13.1689464528248*I,
#  256.81160524887 + 13.1689464528248*I]

请注意复数解。 让我们计算与x0的误差:

sols = np.array(sols, dtype=complex)
error = np.abs(sols - x0)
error
# 输出: array([2.58298170e+02, 1.75363741e+02, 2.87668418e-03, 1.39200765e+01,
       1.39200765e+01])

最后,提取最接近x0的解:

idx = np.argmin(error)
sols[idx]
# 输出: 252.303674956197
英文:

I'm using the SymPy 1.12, and solve is really fast, no need to move onto numerical libraries for now:

def diff_dist_func(a, b, c):
    x = sympy.Symbol('x')
    x0 = sympy.Symbol('x0')
    y0 = sympy.Symbol('y0')

    dist = sympy.sqrt((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2)
    return sympy.diff(dist, x)

a = -0.00020129919480721813
b = 0.10107634020780536
c = -12.305150031126267
shortest_dist = diff_dist_func(a, b, c)

x0 = 252.3007982720215
y0 = 96.55526056735049
sols = sympy.solve(shortest_dist.subs({'x0': x0, 'y0': y0}), x)
sols
# out: [-5.99737132198727,
#  76.9370574126698,
#  252.303674956197,
#  256.81160524887 - 13.1689464528248*I,
#  256.81160524887 + 13.1689464528248*I]

Note the complex solutions. Let's compute the error wrt x0:

sols = np.array(sols, dtype=complex)
error = np.abs(sols - x0)
error
# out: array([2.58298170e+02, 1.75363741e+02, 2.87668418e-03, 1.39200765e+01,
       1.39200765e+01])

Finally, extract the closest solution to x0:

idx = np.argmin(error)
sols[idx]
# out: 252.303674956197

答案2

得分: 1

如果您需要速度并希望获得数值解,scipy.optimize.rootfsolve 是一个遗留函数,因此我更喜欢使用更新的 root 函数)可以用于 lambdify 函数的结果。lambdify 将 sympy 函数转换为 numpy 或 scipy 函数(有选项可以使用其他包)。一旦转换完成,scipy 的根查找器可以正常工作。

from scipy.optimize import root

# 在此插入您的其他代码

solve_shortest_dist_f = sympy.lambdify(sympy.Symbol("x"), solve_shortest_dist)
res = root(solve_shortest_dist_f, x0)
print(res.x[0]) # 252.30367495620214

res 变量包含有关解的一些信息,但 res.x[0] 是您的根。

在计时方面,这需要 6.01 毫秒,而 sympy 解需要 22.2 毫秒。不过,sympy 找到了所有的根,而 scipy 只找到了一个。

英文:

If you do need the speed and want a numerical solution, scipy.optimize.root (fsolve is a legacy function, so I prefer using the newer root function) can work on the result of the lambdify function. lambdify converts a sympy function into a numpy or scipy function (there are options to use other packages). Once converted, the scipy root-finder can work just fine.

from scipy.optimize import root

# your other code here

solve_shortest_dist_f = sympy.lambdify(sympy.Symbol("x"), solve_shortest_dist)
res = root(solve_shortest_dist_f, x0)
print(res.x[0]) # 252.30367495620214

The res variable contains some information about the solve, but res.x[0] is your root.

Timing this, it takes 6.01ms vs sympy solve which takes 22.2ms. Though, sympy is finding all the roots while scipy is just finding one.

huangapple
  • 本文由 发表于 2023年6月29日 14:47:22
  • 转载请务必保留本文链接:https://go.coder-hub.com/76578636.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定