如何在DifferentialEquations.jl中实现一个积分终止回调以解决ODE?

huangapple go评论88阅读模式
英文:

How can I implement an integration termination callback in DifferentialEquations.jl to solve an ODE?

问题

需要帮助在DifferentialEquations.jl中实现一个积分终止回调函数。

你好,

我有以下代码:

  1. function height(dh, h, p, t)
  2. dh[1] = -1*sqrt(h[1])
  3. end
  4. h0 = [14]
  5. tspan = (0.0, 10.0)
  6. prob = ODEProblem(height, h0, tspan, p)

但是当我尝试使用以下方式解决ODE时:

  1. sol = solve(prob)

我得到了以下错误消息:

"DomainError with -0.019520634518403183:
sqrt will only return a complex result if called with a complex argument. Try sqrt(Complex(x))...."

显然,在积分过程中,h[1] 变为负值,从而引发了错误。

我尝试通过积分终止回调函数来解决这个问题,因为我只希望得到h(t) >= 0的解决方案。

以下是我的回调代码:

  1. condition(h, t, integrator) = h[1]
  2. affect!(integrator) = terminate!(integrator)
  3. cb = ContinuousCallback(condition, affect!)

我以为这会在h[1] = 0时终止积分,但是当我尝试:

  1. sol = solve(prob, callback = cb)

时,我得到了相同的错误。我是第一次使用这些回调特性,显然在实现它们方面有一些我不理解的地方。如果你有一些关于我需要改变/修改我的代码以使其工作的想法,我将感激不尽。

谢谢,
Gary

英文:

Need help implementing a integration termination callback in DifferentialEquations.jl.

Greetings,

I have the code

  1. function height(dh, h, p, t)
  2. dh[1] = -1*sqrt(h[1])
  3. end
  4. h0 = [14]
  5. tspan = (0.0, 10.0)
  6. prob = ODEProblem(height, h0, tspan, p)

but when I try solving the ODE with:

  1. sol = solve(prob)

I get:

"DomainError with -0.019520634518403183:
sqrt will only return a complex result if called with a complex argument. Try sqrt(Complex(x))...."

Evidently, during the integration process, h[1] becomes negative valued, thus causing the error.

I tried mitigating the problem with a integration termination callback, since I only want the solution for h(t) >= 0.

Here's my callback code:

  1. condition(h, t, integrator) = h[1]
  2. affect!(integrator) = terminate!(integrator)
  3. cb = ContinuousCallback(condition, affect!)

I thought this would terminate the integration at the timestep when h[1] = 0, but when I then tried:

  1. sol = solve(prob, callback = cb)

I get the same error. I'm new to using these callback features, so clearly there is something I'm not understanding in implementing them. If you have some idea of what I need to change/amend my code to get it working, I would appreciate your feedback.

Thanks,
Gary

答案1

得分: 1

你的方程只有在满足 h(t)>0 的情况下才有解。如果我们对方程进行解析求解,得到:
y'=-√y <=> dy/dt=-√y <=> dy/√y =-dt <=> 2√y=-t +C

其中 h₀=14,我们得到 2√14=C。
因此 -t+C 必须大于或等于0,即 t<=C <=> t<=2√14 =7.483314773547883。

根据这个常见问题解答(FAQ):https://docs.sciml.ai/DiffEqDocs/stable/basics/faq/,我在 ODE 定义中用 sqrt(max(0, h[1])) 替换了 sqrt(h[1])。改变后的代码如下:

  1. using DifferentialEquations, Plots
  2. function height(dh, h, p, t)
  3. dh[1] = -1*sqrt(max(0, h[1]))
  4. end
  5. h0 = [14]
  6. p=[0]
  7. tspan = (0, 10)
  8. condition(h, t, integrator)=h[1]
  9. affect!(integrator) = terminate!(integrator)
  10. cb = ContinuousCallback(condition, affect!)
  11. prob = ODEProblem(height, h0, tspan, p);
  12. sol = solve(prob, callback=cb)

sol.t 是:

  1. 0.13020240538639785
  2. 1.0122592548078624
  3. 2.492004954051219
  4. 3.9874743468989706
  5. 5.525543090227709
  6. 6.22154035990045
  7. 6.677342684567405
  8. 6.9977821858188936
  9. 7.176766608936562
  10. 7.281444926891483
  11. 7.36468081569681
  12. 7.415386846800273
  13. 7.449319963724896
  14. 7.47183932499115
  15. 7.479193894248527
  16. 7.479193894248527

也就是说,最后的 t 接近于 2sqrt(14)。再进行一次时间步长会超过 2sqrt(14)。

绘制 sol,可以将图形延伸到 [0,10]:

  1. plt1 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)",
  2. framestyle=:box, size=(400,300), legend=false)

但是使用以下代码:

  1. plt2 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)",
  2. framestyle=:box, size=(400,300), legend=false)

我们得到了给定 ODE 的解。

英文:

Your equation has a solution only for those t that make h(t)>0. If we solve the equation analytically we have:
y'=-√y <=> dy/dt=-√y <=> dy/√y =-dt <=> 2√y=-t +C

with h₀=14 we get 2√14=C.
hence -t+C must be greater or equal to 0, i.e. t<=C <=> t <=2√14 =7.483314773547883.
Following this FAQ: https://docs.sciml.ai/DiffEqDocs/stable/basics/faq/ I replaced sqrt(h[1]) in the ODE definition by sqrt(max(0, h[1])). With this change the code is as follows:

  1. using DifferentialEquations, Plots
  2. function height(dh, h, p, t)
  3. dh[1] = -1*sqrt(max(0, h[1]))
  4. end
  5. h0 = [14]
  6. p=[0]
  7. tspan = (0, 10)
  8. condition(h, t, integrator)=h[1]
  9. affect!(integrator) = terminate!(integrator)
  10. cb = ContinuousCallback(condition, affect!)
  11. prob = ODEProblem(height, h0, tspan, p);
  12. sol = solve(prob, callback=cb)

sol.t is:

  1. 0.13020240538639785
  2. 1.0122592548078624
  3. 2.492004954051219
  4. 3.9874743468989706
  5. 5.525543090227709
  6. 6.22154035990045
  7. 6.677342684567405
  8. 6.9977821858188936
  9. 7.176766608936562
  10. 7.281444926891483
  11. 7.36468081569681
  12. 7.415386846800273
  13. 7.449319963724896
  14. 7.47183932499115
  15. 7.479193894248527
  16. 7.479193894248527

i.e. the last t is close to 2sqrt(14). One more time step will exceed 2sqrt(14).
Plotting sol, the plot will display an extended sol to [0,10]:

  1. plt1 = plot(sol.t, getindex.(sol.u, 1), xlabel=&quot;Time (t)&quot;, ylabel=&quot;y(t)&quot;,
  2. framestyle=:box, size=(400,300), legend=false)

but with:

  1. plt2 = plot(sol.t, getindex.(sol.u, 1), xlabel=&quot;Time (t)&quot;, ylabel=&quot;y(t)&quot;,
  2. framestyle=:box, size=(400,300), legend=false)

we get the solution of the given ODE

huangapple
  • 本文由 发表于 2023年5月28日 08:50:05
  • 转载请务必保留本文链接:https://go.coder-hub.com/76349566.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定