英文:
How to extract the function from sympy.diff?
问题
I'm here to assist with your translation request. Here is the translated code portion:
我正在尝试使用`sympy.diff`来计算导数函数,然后解决此函数中的`'x'`,使导数函数`func(x) = 0`的根。然而,解决这个导数函数非常慢,因为它返回了五个解,但我只需要最接近固定值`x0`的解。
```python
import sympy
def diff_dist_func(a, b, c):
x = sympy.Symbol('x')
x0 = sympy.Symbol('x0')
y0 = sympy.Symbol('y0')
dist = ((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2) ** (1 / 2)
return sympy.diff(dist, x)
a = -0.00020129919480721813
b = 0.10107634020780536
c = -12.305150031126267
shortest_dist = diff_dist_func(a, b, c)
x0 = 252.3007982720215
y0 = 96.55526056735049
solve_shortest_dist = shortest_dist.evalf(subs={'x0': x0, 'y0': y0}) # 构建导数函数
solve_x = sympy.solve(solve_shortest_dist, sympy.Symbol('x'), simplify=False, rational=False) # 这里解决导数函数非常慢。
为了加速解决方案,我尝试使用scipy.optimize.fsolve
,它能够为func(x) = 0
的根提供一个起始估计值x0
。因此,我用fsolve(solve_shortest_dist, np.array([x0]))
替换了sympy.solve
,但出现了错误TypeError: 'Mul' object is not callable
。如何从sympy.diff
的输出中提取导数函数,使其能够被scipy.optimize.fsolve
解决?或者有没有任何方法来加速解决过程?
英文:
I'm trying to calculate a derivative function using sympy.diff
, and then solve the 'x'
in this function letting the roots of derivative function func(x) = 0
. However, solve this derivative function is very slow, because it returns five solutions, but I only need the solution most close to a fixed value x0
.
import sympy
def diff_dist_func(a, b, c):
x = sympy.Symbol('x')
x0 = sympy.Symbol('x0')
y0 = sympy.Symbol('y0')
dist = ((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2) ** (1 / 2)
return sympy.diff(dist, x)
a = -0.00020129919480721813
b=0.10107634020780536
c=-12.305150031126267
shortest_dist = diff_dist_func(a, b, c)
x0=252.3007982720215
y0=96.55526056735049
solve_shortest_dist = shortest_dist.evalf(subs={'x0': x0, 'y0': y0}) # build the derivative function
solve_x = sympy.solve(solve_shortest_dist, sympy.Symbol('x'), simplify=False, rational=False) # Here solve the derivative function is very slow.
To speed up the solution, I try to use scipy.optimize.fsolve
, which is able to give a starting estimate x0
for the roots of func(x) = 0
. So I replace sympy.solve
with fsolve(solve_shortest_dist, np.array([x0]))
, but an error occured TypeError: 'Mul' object is not callable
. How to extract the derivative function from the output of sympy.diff, enable it to be solved by scipy.optimize.fsolve
? Or are there any way to speed up the solving process?
答案1
得分: 2
我使用SymPy 1.12,`solve`非常快速,暂时不需要转向数值库:
def diff_dist_func(a, b, c):
x = sympy.Symbol('x')
x0 = sympy.Symbol('x0')
y0 = sympy.Symbol('y0')
dist = sympy.sqrt((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2)
return sympy.diff(dist, x)
a = -0.00020129919480721813
b = 0.10107634020780536
c = -12.305150031126267
shortest_dist = diff_dist_func(a, b, c)
x0 = 252.3007982720215
y0 = 96.55526056735049
sols = sympy.solve(shortest_dist.subs({'x0': x0, 'y0': y0}), x)
sols
# 输出: [-5.99737132198727,
# 76.9370574126698,
# 252.303674956197,
# 256.81160524887 - 13.1689464528248*I,
# 256.81160524887 + 13.1689464528248*I]
请注意复数解。 让我们计算与x0
的误差:
sols = np.array(sols, dtype=complex)
error = np.abs(sols - x0)
error
# 输出: array([2.58298170e+02, 1.75363741e+02, 2.87668418e-03, 1.39200765e+01,
1.39200765e+01])
最后,提取最接近x0
的解:
idx = np.argmin(error)
sols[idx]
# 输出: 252.303674956197
英文:
I'm using the SymPy 1.12, and solve
is really fast, no need to move onto numerical libraries for now:
def diff_dist_func(a, b, c):
x = sympy.Symbol('x')
x0 = sympy.Symbol('x0')
y0 = sympy.Symbol('y0')
dist = sympy.sqrt((x - x0) ** 2 + (a * x ** 3 + b * x ** 2 + c * x - y0) ** 2)
return sympy.diff(dist, x)
a = -0.00020129919480721813
b = 0.10107634020780536
c = -12.305150031126267
shortest_dist = diff_dist_func(a, b, c)
x0 = 252.3007982720215
y0 = 96.55526056735049
sols = sympy.solve(shortest_dist.subs({'x0': x0, 'y0': y0}), x)
sols
# out: [-5.99737132198727,
# 76.9370574126698,
# 252.303674956197,
# 256.81160524887 - 13.1689464528248*I,
# 256.81160524887 + 13.1689464528248*I]
Note the complex solutions. Let's compute the error wrt x0
:
sols = np.array(sols, dtype=complex)
error = np.abs(sols - x0)
error
# out: array([2.58298170e+02, 1.75363741e+02, 2.87668418e-03, 1.39200765e+01,
1.39200765e+01])
Finally, extract the closest solution to x0
:
idx = np.argmin(error)
sols[idx]
# out: 252.303674956197
答案2
得分: 1
如果您需要速度并希望获得数值解,scipy.optimize.root
(fsolve
是一个遗留函数,因此我更喜欢使用更新的 root
函数)可以用于 lambdify
函数的结果。lambdify
将 sympy 函数转换为 numpy 或 scipy 函数(有选项可以使用其他包)。一旦转换完成,scipy 的根查找器可以正常工作。
from scipy.optimize import root
# 在此插入您的其他代码
solve_shortest_dist_f = sympy.lambdify(sympy.Symbol("x"), solve_shortest_dist)
res = root(solve_shortest_dist_f, x0)
print(res.x[0]) # 252.30367495620214
res
变量包含有关解的一些信息,但 res.x[0]
是您的根。
在计时方面,这需要 6.01 毫秒,而 sympy 解需要 22.2 毫秒。不过,sympy 找到了所有的根,而 scipy 只找到了一个。
英文:
If you do need the speed and want a numerical solution, scipy.optimize.root
(fsolve
is a legacy function, so I prefer using the newer root
function) can work on the result of the lambdify
function. lambdify
converts a sympy function into a numpy or scipy function (there are options to use other packages). Once converted, the scipy root-finder can work just fine.
from scipy.optimize import root
# your other code here
solve_shortest_dist_f = sympy.lambdify(sympy.Symbol("x"), solve_shortest_dist)
res = root(solve_shortest_dist_f, x0)
print(res.x[0]) # 252.30367495620214
The res
variable contains some information about the solve, but res.x[0]
is your root.
Timing this, it takes 6.01ms vs sympy solve which takes 22.2ms. Though, sympy is finding all the roots while scipy is just finding one.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论