英文:
How can I implement an integration termination callback in DifferentialEquations.jl to solve an ODE?
问题
需要帮助在DifferentialEquations.jl中实现一个积分终止回调函数。
你好,
我有以下代码:
function height(dh, h, p, t)
dh[1] = -1*sqrt(h[1])
end
h0 = [14]
tspan = (0.0, 10.0)
prob = ODEProblem(height, h0, tspan, p)
但是当我尝试使用以下方式解决ODE时:
sol = solve(prob)
我得到了以下错误消息:
"DomainError with -0.019520634518403183:
sqrt will only return a complex result if called with a complex argument. Try sqrt(Complex(x))...."
显然,在积分过程中,h[1] 变为负值,从而引发了错误。
我尝试通过积分终止回调函数来解决这个问题,因为我只希望得到h(t) >= 0的解决方案。
以下是我的回调代码:
condition(h, t, integrator) = h[1]
affect!(integrator) = terminate!(integrator)
cb = ContinuousCallback(condition, affect!)
我以为这会在h[1] = 0时终止积分,但是当我尝试:
sol = solve(prob, callback = cb)
时,我得到了相同的错误。我是第一次使用这些回调特性,显然在实现它们方面有一些我不理解的地方。如果你有一些关于我需要改变/修改我的代码以使其工作的想法,我将感激不尽。
谢谢,
Gary
英文:
Need help implementing a integration termination callback in DifferentialEquations.jl.
Greetings,
I have the code
function height(dh, h, p, t)
dh[1] = -1*sqrt(h[1])
end
h0 = [14]
tspan = (0.0, 10.0)
prob = ODEProblem(height, h0, tspan, p)
but when I try solving the ODE with:
sol = solve(prob)
I get:
"DomainError with -0.019520634518403183:
sqrt will only return a complex result if called with a complex argument. Try sqrt(Complex(x))...."
Evidently, during the integration process, h[1] becomes negative valued, thus causing the error.
I tried mitigating the problem with a integration termination callback, since I only want the solution for h(t) >= 0.
Here's my callback code:
condition(h, t, integrator) = h[1]
affect!(integrator) = terminate!(integrator)
cb = ContinuousCallback(condition, affect!)
I thought this would terminate the integration at the timestep when h[1] = 0, but when I then tried:
sol = solve(prob, callback = cb)
I get the same error. I'm new to using these callback features, so clearly there is something I'm not understanding in implementing them. If you have some idea of what I need to change/amend my code to get it working, I would appreciate your feedback.
Thanks,
Gary
答案1
得分: 1
你的方程只有在满足 h(t)>0 的情况下才有解。如果我们对方程进行解析求解,得到:
y'=-√y <=> dy/dt=-√y <=> dy/√y =-dt <=> 2√y=-t +C
其中 h₀=14,我们得到 2√14=C。
因此 -t+C 必须大于或等于0,即 t<=C <=> t<=2√14 =7.483314773547883。
根据这个常见问题解答(FAQ):https://docs.sciml.ai/DiffEqDocs/stable/basics/faq/,我在 ODE 定义中用 sqrt(max(0, h[1])) 替换了 sqrt(h[1])。改变后的代码如下:
using DifferentialEquations, Plots
function height(dh, h, p, t)
dh[1] = -1*sqrt(max(0, h[1]))
end
h0 = [14]
p=[0]
tspan = (0, 10)
condition(h, t, integrator)=h[1]
affect!(integrator) = terminate!(integrator)
cb = ContinuousCallback(condition, affect!)
prob = ODEProblem(height, h0, tspan, p);
sol = solve(prob, callback=cb)
sol.t 是:
0.13020240538639785
1.0122592548078624
2.492004954051219
3.9874743468989706
5.525543090227709
6.22154035990045
6.677342684567405
6.9977821858188936
7.176766608936562
7.281444926891483
7.36468081569681
7.415386846800273
7.449319963724896
7.47183932499115
7.479193894248527
7.479193894248527
也就是说,最后的 t 接近于 2sqrt(14)。再进行一次时间步长会超过 2sqrt(14)。
绘制 sol,可以将图形延伸到 [0,10]:
plt1 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)",
framestyle=:box, size=(400,300), legend=false)
但是使用以下代码:
plt2 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)",
framestyle=:box, size=(400,300), legend=false)
我们得到了给定 ODE 的解。
英文:
Your equation has a solution only for those t that make h(t)>0. If we solve the equation analytically we have:
y'=-√y <=> dy/dt=-√y <=> dy/√y =-dt <=> 2√y=-t +C
with h₀=14 we get 2√14=C.
hence -t+C must be greater or equal to 0, i.e. t<=C <=> t <=2√14 =7.483314773547883.
Following this FAQ: https://docs.sciml.ai/DiffEqDocs/stable/basics/faq/ I replaced sqrt(h[1]) in the ODE definition by sqrt(max(0, h[1])). With this change the code is as follows:
using DifferentialEquations, Plots
function height(dh, h, p, t)
dh[1] = -1*sqrt(max(0, h[1]))
end
h0 = [14]
p=[0]
tspan = (0, 10)
condition(h, t, integrator)=h[1]
affect!(integrator) = terminate!(integrator)
cb = ContinuousCallback(condition, affect!)
prob = ODEProblem(height, h0, tspan, p);
sol = solve(prob, callback=cb)
sol.t is:
0.13020240538639785
1.0122592548078624
2.492004954051219
3.9874743468989706
5.525543090227709
6.22154035990045
6.677342684567405
6.9977821858188936
7.176766608936562
7.281444926891483
7.36468081569681
7.415386846800273
7.449319963724896
7.47183932499115
7.479193894248527
7.479193894248527
i.e. the last t is close to 2sqrt(14). One more time step will exceed 2sqrt(14).
Plotting sol, the plot will display an extended sol to [0,10]:
plt1 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)",
framestyle=:box, size=(400,300), legend=false)
but with:
plt2 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)",
framestyle=:box, size=(400,300), legend=false)
we get the solution of the given ODE
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论