Matlab vectorization with if and for loops using ODE45 to integrate
Question
I am interested in optimizing the speed of my code. According to "Run and Time", this function is the one that impacts speed the most, but I have a hard time conceptualizing how to vectorize it properly, as I usually just write loops. The attempt I've made runs into an error during the integration in which the function is used. My original function is as follows, and it does not result in an error:
function [dotStates] = ODEFunc(t,states,params)
    % ODE function
    % Loading in and assigning the variables from parameters
    K = params(1);
    N = params(2);
    nn = params(3);
    % magnitude of the coupling based on the number of neighbours
    kn = K/nn;
    w = params(4:end);
    % preallocate the output with the same size as states
    dotStates = states;
    % For each oscillator
    for i = 1:N
        % Use the oscillator's natural frequency
        dotStates(i) = w(i);
        % For j number of neighbours
        for j = (i-nn):(i+nn)
            % neighbour index is positive and less than the number of oscillators
            if (j > 0) && (j < length(dotStates))
                dotStates(i) = dotStates(i) + (kn * sin( states(j)-states(i) ));
            end
        end
    end
end
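For reference, the bound `j < length(dotStates)` in this loop leaves the last oscillator (index N) out of every neighbourhood, while the later attempt uses `<=`. The post does not say whether that is intended. The small check below, with hypothetical values N = 5, nn = 1 and oscillator 4, shows the difference between the two bounds.

```matlab
% Hypothetical sizes: 5 oscillators, 1 neighbour on each side, oscillator 4
N = 5; nn = 1; ii = 4;
j = (ii-nn):(ii+nn);            % candidate neighbours: 3 4 5
keptStrict = j(j > 0 & j < N);  % bound used in the loop: keeps 3 4
keptAll    = j(j > 0 & j <= N); % inclusive bound: keeps 3 4 5
disp(keptStrict)
disp(keptAll)
```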
I've tried following the MathWorks vectorization guide: https://se.mathworks.com/help/matlab/matlab_prog/vectorization.html
My attempt so far follows some of the techniques the guide uses, such as a mask, and produced the following code:
function [dotStates] = ODEFunc(t,states,params)
%ODE function
% Loading in and assigning the variables from parameters
K = params(1);
N = params(2);
nn = params(3);
% magnitude of the coupling based on the number of neighbours
kn = K/nn;
w = params(4:end);
dotStates=states;
% Use the oscillators natural frequency
dotStates = w';
% Mask of j states
j = (i-nn):(i+nn);
% neighbours cannot exceed boundaries
j = j(j>0 & j <= length(dotStates));
jstate = states(j);
jstate(numel(states)) = 0;
dotStates = dotStates + (kn * sin( jstate'-states ));
end
I end up with a vector that is shorter than the one being written to. My workaround has been to pad the "jstate" variable with zeros to make up the difference, but that does not feel like proper vectorization. When I run the code, I get the following warning, which is tied to an integration step:
Warning: Colon operands must be real scalars.
In RK_ODE_2411>ODEFunc (line 99)
In RK_ODE_2411>@(t,states)ODEFunc(t,states,params)
In ode45 (line 324)
In RK_ODE_2411 (line 58)
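One way to vectorize without padding, sketched here under assumptions rather than taken from the post, is to build an N-by-(2*nn+1) matrix of neighbour indices (one row per oscillator, one column per offset), clamp the out-of-range indices to legal values, and zero out their contributions with a logical mask. The values of N, nn, K, w and states below are made up for illustration; w is assumed to be a column vector (use `w(:)` if `params` is a row). Implicit expansion needs MATLAB R2016b or later; older versions can use `bsxfun`. The loop is included only to check that the vectorized result matches it.

```matlab
% Hypothetical example values; in the question these come from params
N  = 8;                            % number of oscillators
nn = 2;                            % neighbours on each side
K  = 1.2;                          % coupling strength
kn = K / nn;                       % per-neighbour coupling, as in ODEFunc
w  = linspace(0.9, 1.1, N)';       % natural frequencies (column)
states = mod((1:N)' * 0.7, 2*pi);  % arbitrary phases (column)

% Reference: the original double loop (same bound: j > 0 and j < N)
ref = zeros(N, 1);
for ii = 1:N
    ref(ii) = w(ii);
    for j = (ii-nn):(ii+nn)
        if (j > 0) && (j < N)
            ref(ii) = ref(ii) + kn * sin(states(j) - states(ii));
        end
    end
end

% Vectorized: row i holds the candidate neighbour indices of oscillator i
J     = (1:N)' + (-nn:nn);         % N x (2*nn+1), via implicit expansion
valid = (J > 0) & (J < N);         % same bound as the loop; J <= N would include oscillator N
Jc    = min(max(J, 1), N);         % clamp so every index is legal; masked out below
D     = states(Jc) - states;       % phase differences states(j) - states(i)
dotStates = w + kn * sum(sin(D) .* valid, 2);
```

The self term (j equal to i) contributes sin(0) = 0 in both versions, so it does not need to be excluded.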
The function is in turn called in the following segment, which performs the integration with ode45:
%% Integration via ODE45
for K = 0:.1:Klen
    params(1) = K;
    K_count = K_count+1;
    nn_count = 0;
    for nn = nnlen:nnlen
        params(3) = nn;
        % index counter
        nn_count = nn_count+1;
        % ode45 uses an explicit Runge-Kutta (4,5) pair (Dormand-Prince)
        sol(K_count,nn_count) = ode45(@(t,states) ODEFunc(t,states,params),tSpan,init,options);
    end
end
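Since K and nn are fixed during each ode45 call, a possible further saving (an assumption, not something measured in the post) is to compute the neighbour indices and mask once per call rather than on every evaluation of the right-hand side, and capture them in an anonymous function. All values below are hypothetical stand-ins for `params`, `tSpan` and `init`, and `options` is omitted. An anonymous function captures kn, w, Jc and valid at the moment it is created, so inside the K loop it would have to be rebuilt whenever K or nn changes, just as the original code captures `params`.

```matlab
% Hypothetical stand-ins for params, tSpan and init
N = 8; nn = 2; K = 1.2;
kn = K / nn;
w = linspace(0.9, 1.1, N)';        % natural frequencies (column)
init = mod((1:N)' * 0.7, 2*pi);    % initial phases (column)
tSpan = [0 10];

% Neighbour bookkeeping, computed once per (K, nn) pair
J     = (1:N)' + (-nn:nn);         % N x (2*nn+1) neighbour indices
valid = (J > 0) & (J < N);         % same bound as the original loop
Jc    = min(max(J, 1), N);         % legal indices; invalid ones are masked

% Right-hand side: only arithmetic on the current state y
rhs = @(t, y) w + kn * sum(sin(y(Jc) - y) .* valid, 2);
[t, y] = ode45(rhs, tSpan, init);  % y has one column per oscillator
```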
Where line 58 is
sol(K_count,nn_count) = ode45(@(t,states) ODEFunc(t,states,params),tSpan,init,options);
EDIT: Line 99 in ODEFunc is
j = (i-nn):(i+nn);
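The warning points at this line because the vectorized version removed the `for i=1:N` loop, so `i` is never assigned inside ODEFunc. MATLAB then falls back to `i` as the built-in imaginary unit, which makes both colon operands complex numbers, and that is exactly what "Colon operands must be real scalars" complains about. The snippet below, with a hypothetical nn = 2, shows that the operands are complex. The usual remedy is to build the neighbour indices for all oscillators at once (for example as a matrix with one row per oscillator) instead of relying on a scalar `i`.

```matlab
% With no variable named i in scope, i is MATLAB's imaginary unit
nn = 2;                 % hypothetical neighbour count
lo = i - nn;            % -2 + 1i, a complex number
hi = i + nn;            %  2 + 1i
bothComplex = ~isreal(lo) && ~isreal(hi);
disp(bothComplex)
% Using lo:hi as a range is what triggers the "Colon operands" warning
```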
Answer 1
Score: 1
Try this snippet:
% For each oscillator
for i = 1:N
    % Candidate neighbour indices of oscillator i
    j = (i-nn):(i+nn);
    % Keep neighbours whose index is positive and less than the number of oscillators
    lg = (j > 0) & (j < length(dotStates));
    % Apply the mask to j (not directly to states) to pick the neighbours' states
    dotStates(i) = w(i) + kn * sum(sin(states(j(lg)) - states(i)));
end
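The mask `lg` has one entry per element of `j`, so it selects positions within `j`, not within `states`. Indexing `states(lg)` directly returns the first elements of `states` rather than the neighbours of oscillator i; the mask has to be applied to `j` first, as `states(j(lg))`. A small check with hypothetical, easily recognisable state values:

```matlab
% Hypothetical example: 6 oscillators with recognisable states
states = (10:10:60)';
N = numel(states); nn = 1; ii = 4;
j  = (ii-nn):(ii+nn);           % neighbour indices 3 4 5
lg = (j > 0) & (j < N);         % logical mask over j: all true here
wrongPick = states(lg);         % first three states: 10 20 30
rightPick = states(j(lg));      % states 3 to 5: 30 40 50
```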
The most important thing, though, is that dotStates never becomes larger than states, since that would force MATLAB to reallocate its memory, which slows the code down enormously.
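As a generic illustration of the reallocation point (not code from the post), the sketch below contrasts growing an array inside a loop with allocating it once at full size. In the question's attempt, `jstate(numel(states)) = 0` is a similar enlargement of an existing array. N is a hypothetical size.

```matlab
N = 1e4;                    % hypothetical size
grown = [];
for k = 1:N
    grown(k, 1) = k;        % array is enlarged (and may be copied) each iteration
end
prealloc = zeros(N, 1);     % allocated once, at full size
for k = 1:N
    prealloc(k) = k;        % writes into existing storage, no resizing
end
```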