如何在DifferentialEquations.jl中实现一个积分终止回调以解决ODE?

huangapple go评论54阅读模式
英文:

How can I implement an integration termination callback in DifferentialEquations.jl to solve an ODE?

问题

需要帮助在DifferentialEquations.jl中实现一个积分终止回调函数。

你好,

我有以下代码:

    function height(dh, h, p, t)
	dh[1] = -1*sqrt(h[1])
    end

    h0 = [14]
    tspan = (0.0, 10.0)

    prob = ODEProblem(height, h0, tspan, p)

但是当我尝试使用以下方式解决ODE时:

    sol = solve(prob)

我得到了以下错误消息:

"DomainError with -0.019520634518403183:
sqrt will only return a complex result if called with a complex argument. Try sqrt(Complex(x))...."

显然,在积分过程中,h[1] 变为负值,从而引发了错误。

我尝试通过积分终止回调函数来解决这个问题,因为我只希望得到h(t) >= 0的解决方案。

以下是我的回调代码:

    condition(h, t, integrator) = h[1]
    affect!(integrator) = terminate!(integrator)
    cb = ContinuousCallback(condition, affect!)

我以为这会在h[1] = 0时终止积分,但是当我尝试:

    sol = solve(prob, callback = cb)

时,我得到了相同的错误。我是第一次使用这些回调特性,显然在实现它们方面有一些我不理解的地方。如果你有一些关于我需要改变/修改我的代码以使其工作的想法,我将感激不尽。

谢谢,
Gary

英文:

Need help implementing a integration termination callback in DifferentialEquations.jl.

Greetings,

I have the code

    function height(dh, h, p, t)
	dh[1] = -1*sqrt(h[1])
    end

    h0 = [14]
    tspan = (0.0, 10.0)

    prob = ODEProblem(height, h0, tspan, p)

but when I try solving the ODE with:

    sol = solve(prob)

I get:

"DomainError with -0.019520634518403183:
sqrt will only return a complex result if called with a complex argument. Try sqrt(Complex(x))...."

Evidently, during the integration process, h[1] becomes negative valued, thus causing the error.

I tried mitigating the problem with a integration termination callback, since I only want the solution for h(t) >= 0.

Here's my callback code:

    condition(h, t, integrator) = h[1]
    affect!(integrator) = terminate!(integrator)
    cb = ContinuousCallback(condition, affect!)

I thought this would terminate the integration at the timestep when h[1] = 0, but when I then tried:

    sol = solve(prob, callback = cb)

I get the same error. I'm new to using these callback features, so clearly there is something I'm not understanding in implementing them. If you have some idea of what I need to change/amend my code to get it working, I would appreciate your feedback.

Thanks,
Gary

答案1

得分: 1

你的方程只有在满足 h(t)>0 的情况下才有解。如果我们对方程进行解析求解,得到:
y'=-√y <=> dy/dt=-√y <=> dy/√y =-dt <=> 2√y=-t +C

其中 h₀=14,我们得到 2√14=C。
因此 -t+C 必须大于或等于0,即 t<=C <=> t<=2√14 =7.483314773547883。

根据这个常见问题解答(FAQ):https://docs.sciml.ai/DiffEqDocs/stable/basics/faq/,我在 ODE 定义中用 sqrt(max(0, h[1])) 替换了 sqrt(h[1])。改变后的代码如下:

using DifferentialEquations, Plots
function height(dh, h, p, t)
    dh[1] =   -1*sqrt(max(0, h[1]))
end
h0 = [14]
p=[0]
tspan = (0, 10)
condition(h, t, integrator)=h[1]
affect!(integrator) = terminate!(integrator)
cb = ContinuousCallback(condition, affect!)
prob = ODEProblem(height, h0, tspan, p);
sol = solve(prob, callback=cb)

sol.t 是:

0.13020240538639785
1.0122592548078624
2.492004954051219
3.9874743468989706
5.525543090227709
6.22154035990045
6.677342684567405
6.9977821858188936
7.176766608936562
7.281444926891483
7.36468081569681
7.415386846800273
7.449319963724896
7.47183932499115
7.479193894248527
7.479193894248527

也就是说,最后的 t 接近于 2sqrt(14)。再进行一次时间步长会超过 2sqrt(14)。

绘制 sol,可以将图形延伸到 [0,10]:

plt1 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)", 
          framestyle=:box, size=(400,300), legend=false)

但是使用以下代码:

plt2 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)", 
          framestyle=:box, size=(400,300), legend=false)

我们得到了给定 ODE 的解。

英文:

Your equation has a solution only for those t that make h(t)>0. If we solve the equation analytically we have:
y'=-√y <=> dy/dt=-√y <=> dy/√y =-dt <=> 2√y=-t +C

with h₀=14 we get 2√14=C.
hence -t+C must be greater or equal to 0, i.e. t<=C <=> t <=2√14 =7.483314773547883.
Following this FAQ: https://docs.sciml.ai/DiffEqDocs/stable/basics/faq/ I replaced sqrt(h[1]) in the ODE definition by sqrt(max(0, h[1])). With this change the code is as follows:

using DifferentialEquations, Plots
function height(dh, h, p, t)
    dh[1] =   -1*sqrt(max(0, h[1]))
end
h0 = [14]
p=[0]
tspan = (0, 10)
condition(h, t, integrator)=h[1]
affect!(integrator) = terminate!(integrator)
cb = ContinuousCallback(condition, affect!)
prob = ODEProblem(height, h0, tspan, p);
sol = solve(prob, callback=cb)

sol.t is:

0.13020240538639785
 1.0122592548078624
 2.492004954051219
 3.9874743468989706
 5.525543090227709
 6.22154035990045
 6.677342684567405
 6.9977821858188936
 7.176766608936562
 7.281444926891483
 7.36468081569681
 7.415386846800273
 7.449319963724896
 7.47183932499115
 7.479193894248527
 7.479193894248527

i.e. the last t is close to 2sqrt(14). One more time step will exceed 2sqrt(14).
Plotting sol, the plot will display an extended sol to [0,10]:

plt1 = plot(sol.t, getindex.(sol.u, 1), xlabel=&quot;Time (t)&quot;, ylabel=&quot;y(t)&quot;, 
          framestyle=:box, size=(400,300), legend=false)

but with:

plt2 = plot(sol.t, getindex.(sol.u, 1), xlabel=&quot;Time (t)&quot;, ylabel=&quot;y(t)&quot;, 
          framestyle=:box, size=(400,300), legend=false)

we get the solution of the given ODE

huangapple
  • 本文由 发表于 2023年5月28日 08:50:05
  • 转载请务必保留本文链接:https://go.coder-hub.com/76349566.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定