Python streamline algorithm

Question

Goal:

I have two arrays, vx and vy, representing velocity components. I want to write a streamline algorithm that does the following:

  1. Take the coordinates of a point (the seed) as input
  2. Evaluate which pixels are on the path of the seed point, based on the velocity components
  3. Return the indices of the points on the path of the seed point

Issue/Question:

I initially wrote a forward Euler algorithm, but it solved my problem very poorly. I was advised to treat the problem as an ordinary differential equation (ODE), where dx/dt = v_x(t) and dy/dt = v_y(t).
I can interpolate my velocities, but I am struggling to solve the ODE with SciPy. How can I do that?
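
For readers comparing the two approaches, here is a minimal sketch of why forward Euler tends to drift while a Runge-Kutta solver does not. It assumes a hypothetical rigid-rotation field (not the field from this question), whose exact streamlines are circles, and it writes the ODE with the velocity depending on position: dx/dt = v_x(x, y) and dy/dt = v_y(x, y). The names `velocity`, `seed`, `euler_radius` and `rk_radius` are illustrative only.

```python
import numpy as np
from scipy.integrate import solve_ivp


def velocity(t, pos):
    """Assumed test field: rigid rotation about the origin (circular streamlines)."""
    x, y = pos
    return [-y, x]


seed = np.array([1.0, 0.0])
t_end = 2.0 * np.pi          # one full revolution
n_steps = 200
dt = t_end / n_steps

# Forward Euler: one velocity evaluation per step, first-order accurate.
pos = seed.copy()
for _ in range(n_steps):
    pos = pos + dt * np.asarray(velocity(0.0, pos))
euler_radius = np.hypot(pos[0], pos[1])

# Adaptive Runge-Kutta (RK45 is the solve_ivp default).
sol = solve_ivp(velocity, [0.0, t_end], seed, rtol=1e-8, atol=1e-10)
rk_radius = np.hypot(sol.y[0, -1], sol.y[1, -1])

print("radius after one revolution, Euler:", euler_radius)
print("radius after one revolution, RK45: ", rk_radius)
```

With these settings the Euler path spirals outward by roughly 10% after one revolution, while the `solve_ivp` path stays close to the unit circle. Note that `t` is only the integration parameter here; for a steady field the right-hand side ignores it.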

Homemade algorithm:

I have two arrays, vx and vy, representing velocity components. Wherever one has a NaN, the other has one too. I start from a seed point and want to track which cells this point passes through, based on the velocity components.
I interpolate the velocity components vx and vy so that I can feed them to an ODE solver.
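
Since the arrays contain NaNs, one possible way to keep them away from the solver is sketched below. It is only an illustration under stated assumptions: the flow is a made-up uniform field with a hypothetical no-data band, NaNs are replaced by zeros before interpolation, and a separate nearest-neighbour validity interpolator (`fvalid`, a name introduced here) drives a terminal event that stops the integration when the path enters no-data cells or leaves the grid.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.interpolate import RegularGridInterpolator

# Assumed test grid and flow: uniform motion in +x, with NaNs from column 60 on.
x = np.linspace(0.0, 10.0, 101)
y = np.linspace(0.0, 5.0, 51)
X, Y = np.meshgrid(x, y)
vx = np.ones_like(X)
vy = np.zeros_like(X)
vx[:, 60:] = np.nan
vy[:, 60:] = np.nan

# Validity mask; NaNs are zero-filled so the interpolators never return NaN.
valid = np.isfinite(vx) & np.isfinite(vy)
fx = RegularGridInterpolator((y, x), np.where(valid, vx, 0.0),
                             bounds_error=False, fill_value=0.0)
fy = RegularGridInterpolator((y, x), np.where(valid, vy, 0.0),
                             bounds_error=False, fill_value=0.0)
fvalid = RegularGridInterpolator((y, x), valid.astype(float), method="nearest",
                                 bounds_error=False, fill_value=0.0)


def velocity(t, pos):
    # State is ordered (x, y); the interpolators expect (y, x).
    q = np.array([[pos[1], pos[0]]])
    return [fx(q)[0], fy(q)[0]]


def hits_nodata(t, pos):
    # +0.5 over valid cells, -0.5 over NaN cells or outside the grid.
    q = np.array([[pos[1], pos[0]]])
    return fvalid(q)[0] - 0.5


hits_nodata.terminal = True  # stop the integration at the first sign change

sol = solve_ivp(velocity, [0.0, 100.0], [1.0, 2.0],
                events=hits_nodata, max_step=0.05)
print("stopped at x =", sol.y[0, -1], "y =", sol.y[1, -1])
```

The path stops about halfway between the last valid column and the first NaN column, because that is where the nearest-neighbour mask switches. Zero-filling also slows the flow in that last half cell, so treat the final few points with care.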

Example:

This code tests the algorithm on a synthetic velocity field (90 x 100 with the grid below). I am stuck at the ODE solver step.

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.interpolate import RegularGridInterpolator
    from scipy.integrate import odeint

    # Create coordinates
    x = np.linspace(0, 10, 100)
    y = np.linspace(11, 20, 90)
    # meshgrid returns (X, Y); both arrays have shape (len(y), len(x))
    X, Y = np.meshgrid(x, y)

    # Create velocity fields
    vx = -1 - X**2 + Y
    vy = 1 + X - Y**2

    # Seed point
    J = 5
    I = 14

    # Interpolate the velocity fields (grid axes in (y, x) order, matching the array shape)
    interpvx = RegularGridInterpolator((y, x), vx)
    interpvy = RegularGridInterpolator((y, x), vy)

    # Solve the ODE to get the point's path, but I don't know what to put for the parameter t
    # solx = odeint(interpvx, interpvx((I, J)), np.linspace(0, 1, 501))
    # soly = odeint(interpvy, interpvx((I, J)), np.linspace(0, 1, 501))

Answer 1

Score: 0

A colleague came up with a solution, so I am not taking credit for this answer. But since there does not seem to be a ready-made streamline algorithm for Python, hopefully it will be useful to someone here:

    import numpy as np
    from scipy.integrate import solve_ivp
    from scipy.interpolate import RegularGridInterpolator

    # Create coordinates
    x = np.linspace(0, 1000, 100)
    y = np.linspace(0, 1000, 90)
    # meshgrid returns (X, Y); both arrays have shape (len(y), len(x))
    X, Y = np.meshgrid(x, y)

    # Create velocity fields
    vx = -1 - X**2 + Y
    vy = 1 + X - Y**2

    # Seed point
    J = 5
    I = 14

    # Interpolate the velocities. The grid axes are given in (y, x) order to
    # match the array shape; fill_value=None extrapolates outside the grid.
    fx = RegularGridInterpolator((y, x), vx, bounds_error=False, fill_value=None)
    fy = RegularGridInterpolator((y, x), vy, bounds_error=False, fill_value=None)

    # Define the velocity function to be integrated. The state `pos` is
    # ordered (y, x), like the interpolator axes, so it returns (vy, vx).
    def f(t, pos):
        return np.squeeze([fy(pos), fx(pos)])

    # Solve for a seed point; the initial state [J, I] is read as (y, x)
    sol = solve_ivp(f, [0, 100], [J, I], t_eval=np.arange(0, 100, 1))
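
To go from an integrated path to the cell indices asked for in the question, one option is to stop the integration at the grid boundary and then round each path point to its nearest grid node. The sketch below is not part of the answer above. It assumes a uniform test flow instead of the field used there, a state ordered (x, y), and hypothetical helper names (`velocity`, `leaves_domain`, `path_indices`). `max_step` is kept well below one cell width so that consecutive points do not jump over cells.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.interpolate import RegularGridInterpolator

# Assumed test grid and flow: constant velocity (1, 0.25).
x = np.linspace(0.0, 10.0, 101)
y = np.linspace(0.0, 5.0, 51)
X, Y = np.meshgrid(x, y)
vx = np.ones_like(X)
vy = 0.25 * np.ones_like(Y)

fx = RegularGridInterpolator((y, x), vx, bounds_error=False, fill_value=None)
fy = RegularGridInterpolator((y, x), vy, bounds_error=False, fill_value=None)


def velocity(t, pos):
    # State is ordered (x, y); the interpolators expect (y, x).
    q = np.array([[pos[1], pos[0]]])
    return [fx(q)[0], fy(q)[0]]


def leaves_domain(t, pos):
    # Positive inside the grid, zero on its boundary, negative outside.
    px, py = pos
    return min(px - x[0], x[-1] - px, py - y[0], y[-1] - py)


leaves_domain.terminal = True  # stop when the path reaches the boundary

seed = (2.0, 1.0)
sol = solve_ivp(velocity, [0.0, 100.0], seed,
                events=leaves_domain, max_step=0.05)

# Round each path point to the nearest grid node: (row, column) = (iy, ix).
ix = np.rint((sol.y[0] - x[0]) / (x[1] - x[0])).astype(int)
iy = np.rint((sol.y[1] - y[0]) / (y[1] - y[0])).astype(int)
ix = np.clip(ix, 0, len(x) - 1)
iy = np.clip(iy, 0, len(y) - 1)
cells = np.column_stack([iy, ix])

# Drop consecutive duplicates so each visited cell is listed once per visit.
keep = np.ones(len(cells), dtype=bool)
keep[1:] = np.any(np.diff(cells, axis=0) != 0, axis=1)
path_indices = cells[keep]
print(path_indices[:5])
```

`path_indices[k]` can then be used directly as `vx[tuple(path_indices[k])]`. Rounding to the nearest node is a design choice; flooring instead would give the cell whose lower-left corner is that node.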

Answer 2

Score: 0

The code of matplotlib.pyplot.streamplot can be checked on GitHub. I needed similar functionality, but with access to the points that make up the streamlines, not just the plot. I slightly modified the code behind matplotlib.pyplot.streamplot by removing the statements related to plotting and returning the sequence of points in each streamline instead. As mentioned in the original comments in the matplotlib code, the streamline method uses a second-order Runge-Kutta algorithm with adaptive step size to generate the streamlines. The modified code is below; I have renamed the function streamplot2. I didn't remove the parameters I don't use, but of course that can be done to make the solution cleaner.

You can call the function as follows:

    streams = streamplot2(X, Y, dx, dy, broken_streamlines=False,
                          integration_direction="forward",
                          start_points=start_points_coords)

where start_points_coords is an array of shape (N, 2), with N equal to the number of starting points (N can be 1 for a single streamline).

The function returns a list of streamlines, where each streamline is an array of shape (M, 2) and M is the number of points in that streamline.

You can then plot the streamlines in the following way (no frills):

    for stream in streams:
        plt.scatter(stream[:, 0], stream[:, 1], s=1)
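
As a side note, if keeping a modified copy of the matplotlib code is not an option, the stock function also exposes the vertices it drew through the returned `lines` collection. This is a hedged sketch, not part of the approach above: it assumes a single seed, a headless backend, and a small made-up grid, and it stitches the segments back together in a way that should work whether matplotlib stores the streamline as one polyline or as many two-point pieces (which depends on the version).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Assumed test field on a small regular grid.
Y, X = np.mgrid[-3:3:100j, -3:3:100j]
U = -1 - X**2 + Y
V = 1 + X - Y**2

fig, ax = plt.subplots()
result = ax.streamplot(X, Y, U, V,
                       start_points=np.array([[0.0, 0.0]]),
                       integration_direction="forward")

# `result.lines` is a LineCollection; each segment is a (k, 2) vertex array.
segments = result.lines.get_segments()

# Keep the first vertex of the first segment, then every later vertex.
points = np.vstack([segments[0][:1]] + [seg[1:] for seg in segments])
print(points.shape)
plt.close(fig)
```

With several seeds this simple stitching would merge all streamlines into one array, so it only suits the single-seed case.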

Full modified code below:

    """
    Streamlines for 2D vector fields.
    """

    import numpy as np

    __all__ = ['streamplot2']


    def streamplot2(x, y, u, v, density=1, linewidth=None, color=None,
                    cmap=None, norm=None, arrowsize=1, arrowstyle='-|>',
                    minlength=0.1, transform=None, zorder=None, start_points=None,
                    maxlength=4.0, integration_direction='both',
                    broken_streamlines=True):
        """
        Compute streamlines of a vector flow (nothing is drawn).

        The plotting parameters (*linewidth*, *color*, *cmap*, *norm*,
        *arrowsize*, *arrowstyle*, *transform*, *zorder*) are kept from the
        original signature but are not used here.

        Parameters
        ----------
        x, y : 1D/2D arrays
            Evenly spaced strictly increasing arrays to make a grid. If 2D, all
            rows of *x* must be equal and all columns of *y* must be equal; i.e.,
            they must be as if generated by ``np.meshgrid(x_1d, y_1d)``.
        u, v : 2D arrays
            *x* and *y*-velocities. The number of rows and columns must match
            the length of *y* and *x*, respectively.
        density : float or (float, float)
            Controls the closeness of streamlines. When ``density = 1``, the domain
            is divided into a 30x30 grid. *density* linearly scales this grid.
            Each cell in the grid can have, at most, one traversing streamline.
            For different densities in each direction, use a tuple
            (density_x, density_y).
        linewidth : float or 2D array
            The width of the streamlines. With a 2D array the line width can be
            varied across the grid. The array must have the same shape as *u*
            and *v*.
        color : color or 2D array
            The streamline color. If given an array, its values are converted to
            colors using *cmap* and *norm*. The array must have the same shape
            as *u* and *v*.
        cmap, norm
            Data normalization and colormapping parameters for *color*; only used
            if *color* is an array of floats. See `~.Axes.imshow` for a detailed
            description.
        arrowsize : float
            Scaling factor for the arrow size.
        arrowstyle : str
            Arrow style specification.
            See `~matplotlib.patches.FancyArrowPatch`.
        minlength : float
            Minimum length of streamline in axes coordinates.
        start_points : (N, 2) array
            Coordinates of starting points for the streamlines in data coordinates
            (the same coordinates as the *x* and *y* arrays).
        zorder : float
            The zorder of the streamlines and arrows.
            Artists with lower zorder values are drawn first.
        maxlength : float
            Maximum length of streamline in axes coordinates.
        integration_direction : {'forward', 'backward', 'both'}, default: 'both'
            Integrate the streamline in forward, backward or both directions.
        broken_streamlines : boolean, default: True
            If False, forces streamlines to continue until they
            leave the plot domain. If True, they may be terminated if they
            come too close to another streamline.

        Returns
        -------
        list of (M, 2) arrays
            One array per streamline, holding the M points of that streamline
            in data coordinates.
        """
        grid = Grid(x, y)
        mask = StreamMask(density)
        dmap = DomainMap(grid, mask)

        u = np.ma.masked_invalid(u)
        v = np.ma.masked_invalid(v)

        integrate = _get_integrator(u, v, dmap, minlength, maxlength,
                                    integration_direction)

        trajectories = []
        if start_points is None:
            for xm, ym in _gen_starting_points(mask.shape):
                if mask[ym, xm] == 0:
                    xg, yg = dmap.mask2grid(xm, ym)
                    t = integrate(xg, yg, broken_streamlines)
                    if t is not None:
                        trajectories.append(t)
        else:
            sp2 = np.asanyarray(start_points, dtype=float).copy()

            # Check if start_points are outside the data boundaries
            for xs, ys in sp2:
                if not (grid.x_origin <= xs <= grid.x_origin + grid.width and
                        grid.y_origin <= ys <= grid.y_origin + grid.height):
                    raise ValueError(f"Starting point ({xs}, {ys}) outside of "
                                     "data boundaries")

            # Convert start_points from data to array coords
            # Shift the seed points from the bottom left of the data so that
            # data2grid works properly.
            sp2[:, 0] -= grid.x_origin
            sp2[:, 1] -= grid.y_origin

            for xs, ys in sp2:
                xg, yg = dmap.data2grid(xs, ys)
                # Floating point issues can cause xg, yg to be slightly out of
                # bounds for xs, ys on the upper boundaries. Because we have
                # already checked that the starting points are within the original
                # grid, clip the xg, yg to the grid to work around this issue
                xg = np.clip(xg, 0, grid.nx - 1)
                yg = np.clip(yg, 0, grid.ny - 1)

                t = integrate(xg, yg, broken_streamlines)
                if t is not None:
                    trajectories.append(t)

        streamlines = []
        for t in trajectories:
            tgx, tgy = t.T
            # Rescale from grid-coordinates to data-coordinates.
            tx, ty = dmap.grid2data(tgx, tgy)
            tx += grid.x_origin
            ty += grid.y_origin
            points = np.transpose([tx, ty])
            streamlines.append(points)

        return streamlines


    class StreamplotSet:

        def __init__(self, lines, arrows):
            self.lines = lines
            self.arrows = arrows


    # Coordinate definitions
    # ========================

    class DomainMap:
        """
        Map representing different coordinate systems.

        Coordinate definitions:

        * axes-coordinates goes from 0 to 1 in the domain.
        * data-coordinates are specified by the input x-y coordinates.
        * grid-coordinates goes from 0 to N and 0 to M for an N x M grid,
          where N and M match the shape of the input data.
        * mask-coordinates goes from 0 to N and 0 to M for an N x M mask,
          where N and M are user-specified to control the density of streamlines.

        This class also has methods for adding trajectories to the StreamMask.
        Before adding a trajectory, run `start_trajectory` to keep track of regions
        crossed by a given trajectory. Later, if you decide the trajectory is bad
        (e.g., if the trajectory is very short) just call `undo_trajectory`.
        """

        def __init__(self, grid, mask):
            self.grid = grid
            self.mask = mask
            # Constants for conversion between grid- and mask-coordinates
            self.x_grid2mask = (mask.nx - 1) / (grid.nx - 1)
            self.y_grid2mask = (mask.ny - 1) / (grid.ny - 1)

            self.x_mask2grid = 1. / self.x_grid2mask
            self.y_mask2grid = 1. / self.y_grid2mask

            self.x_data2grid = 1. / grid.dx
            self.y_data2grid = 1. / grid.dy

        def grid2mask(self, xi, yi):
            """Return nearest space in mask-coords from given grid-coords."""
            return round(xi * self.x_grid2mask), round(yi * self.y_grid2mask)

        def mask2grid(self, xm, ym):
            return xm * self.x_mask2grid, ym * self.y_mask2grid

        def data2grid(self, xd, yd):
            return xd * self.x_data2grid, yd * self.y_data2grid

        def grid2data(self, xg, yg):
            return xg / self.x_data2grid, yg / self.y_data2grid

        def start_trajectory(self, xg, yg, broken_streamlines=True):
            xm, ym = self.grid2mask(xg, yg)
            self.mask._start_trajectory(xm, ym, broken_streamlines)

        def reset_start_point(self, xg, yg):
            xm, ym = self.grid2mask(xg, yg)
            self.mask._current_xy = (xm, ym)

        def update_trajectory(self, xg, yg, broken_streamlines=True):
            if not self.grid.within_grid(xg, yg):
                raise InvalidIndexError
            xm, ym = self.grid2mask(xg, yg)
            self.mask._update_trajectory(xm, ym, broken_streamlines)

        def undo_trajectory(self):
            self.mask._undo_trajectory()


    class Grid:
        """Grid of data."""

        def __init__(self, x, y):

            if np.ndim(x) == 1:
                pass
            elif np.ndim(x) == 2:
                x_row = x[0]
                if not np.allclose(x_row, x):
                    raise ValueError("The rows of 'x' must be equal")
                x = x_row
            else:
                raise ValueError("'x' can have at maximum 2 dimensions")

            if np.ndim(y) == 1:
                pass
            elif np.ndim(y) == 2:
                yt = np.transpose(y)  # Also works for nested lists.
                y_col = yt[0]
                if not np.allclose(y_col, yt):
                    raise ValueError("The columns of 'y' must be equal")
                y = y_col
            else:
                raise ValueError("'y' can have at maximum 2 dimensions")

            if not (np.diff(x) > 0).all():
                raise ValueError("'x' must be strictly increasing")
            if not (np.diff(y) > 0).all():
                raise ValueError("'y' must be strictly increasing")

            self.nx = len(x)
            self.ny = len(y)

            self.dx = x[1] - x[0]
            self.dy = y[1] - y[0]

            self.x_origin = x[0]
            self.y_origin = y[0]

            self.width = x[-1] - x[0]
            self.height = y[-1] - y[0]

            if not np.allclose(np.diff(x), self.width / (self.nx - 1)):
                raise ValueError("'x' values must be equally spaced")
            if not np.allclose(np.diff(y), self.height / (self.ny - 1)):
                raise ValueError("'y' values must be equally spaced")

        @property
        def shape(self):
            return self.ny, self.nx

        def within_grid(self, xi, yi):
            """Return whether (*xi*, *yi*) is a valid index of the grid."""
            # Note that xi/yi can be floats; so, for example, we can't simply check
            # `xi < self.nx` since *xi* can be `self.nx - 1 < xi < self.nx`
            return 0 <= xi <= self.nx - 1 and 0 <= yi <= self.ny - 1


    class StreamMask:
        """
        Mask to keep track of discrete regions crossed by streamlines.

        The resolution of this grid determines the approximate spacing between
        trajectories. Streamlines are only allowed to pass through zeroed cells:
        When a streamline enters a cell, that cell is set to 1, and no new
        streamlines are allowed to enter.
        """

        def __init__(self, density):
            try:
                self.nx, self.ny = (30 * np.broadcast_to(density, 2)).astype(int)
            except ValueError as err:
                raise ValueError("'density' must be a scalar or be of length "
                                 "2") from err
            if self.nx < 0 or self.ny < 0:
                raise ValueError("'density' must be positive")
            self._mask = np.zeros((self.ny, self.nx))
            self.shape = self._mask.shape

            self._current_xy = None

        def __getitem__(self, args):
            return self._mask[args]

        def _start_trajectory(self, xm, ym, broken_streamlines=True):
            """Start recording streamline trajectory"""
            self._traj = []
            self._update_trajectory(xm, ym, broken_streamlines)

        def _undo_trajectory(self):
            """Remove current trajectory from mask"""
            for t in self._traj:
                self._mask[t] = 0

        def _update_trajectory(self, xm, ym, broken_streamlines=True):
            """
            Update current trajectory position in mask.

            If the new position has already been filled, raise `InvalidIndexError`.
            """
            if self._current_xy != (xm, ym):
                if self[ym, xm] == 0:
                    self._traj.append((ym, xm))
                    self._mask[ym, xm] = 1
                    self._current_xy = (xm, ym)
                else:
                    if broken_streamlines:
                        raise InvalidIndexError
                    else:
                        pass


    class InvalidIndexError(Exception):
        pass


    class TerminateTrajectory(Exception):
        pass


    # Integrator definitions
    # =======================

    def _get_integrator(u, v, dmap, minlength, maxlength, integration_direction):

        # rescale velocity onto grid-coordinates for integrations.
        u, v = dmap.data2grid(u, v)

        # speed (path length) will be in axes-coordinates
        u_ax = u / (dmap.grid.nx - 1)
        v_ax = v / (dmap.grid.ny - 1)
        speed = np.ma.sqrt(u_ax ** 2 + v_ax ** 2)

        def forward_time(xi, yi):
            if not dmap.grid.within_grid(xi, yi):
                raise OutOfBounds
            ds_dt = interpgrid(speed, xi, yi)
            if ds_dt == 0:
                raise TerminateTrajectory()
            dt_ds = 1. / ds_dt
            ui = interpgrid(u, xi, yi)
            vi = interpgrid(v, xi, yi)
            return ui * dt_ds, vi * dt_ds

        def backward_time(xi, yi):
            dxi, dyi = forward_time(xi, yi)
            return -dxi, -dyi

        def integrate(x0, y0, broken_streamlines=True):
            """
            Return x, y grid-coordinates of trajectory based on starting point.

            Integrate both forward and backward in time from starting point in
            grid coordinates.

            Integration is terminated when a trajectory reaches a domain boundary
            or when it crosses into an already occupied cell in the StreamMask. The
            resulting trajectory is None if it is shorter than `minlength`.
            """

            stotal, xy_traj = 0., []

            try:
                dmap.start_trajectory(x0, y0, broken_streamlines)
            except InvalidIndexError:
                return None
            if integration_direction in ['both', 'backward']:
                s, xyt = _integrate_rk12(x0, y0, dmap, backward_time, maxlength,
                                         broken_streamlines)
                stotal += s
                xy_traj += xyt[::-1]

            if integration_direction in ['both', 'forward']:
                dmap.reset_start_point(x0, y0)
                s, xyt = _integrate_rk12(x0, y0, dmap, forward_time, maxlength,
                                         broken_streamlines)
                stotal += s
                xy_traj += xyt[1:]

            if stotal > minlength:
                return np.broadcast_arrays(xy_traj, np.empty((1, 2)))[0]
            else:  # reject short trajectories
                dmap.undo_trajectory()
                return None

        return integrate


    class OutOfBounds(IndexError):
        pass


    def _integrate_rk12(x0, y0, dmap, f, maxlength, broken_streamlines=True):
        """
        2nd-order Runge-Kutta algorithm with adaptive step size.

        This method is also referred to as the improved Euler's method, or Heun's
        method. This method is favored over higher-order methods because:

        1. To get decent looking trajectories and to sample every mask cell
           on the trajectory we need a small timestep, so a lower order
           solver doesn't hurt us unless the data is *very* high resolution.
           In fact, for cases where the user inputs
           data smaller or of similar grid size to the mask grid, the higher
           order corrections are negligible because of the very fast linear
           interpolation used in `interpgrid`.

        2. For high resolution input data (i.e. beyond the mask
           resolution), we must reduce the timestep. Therefore, an adaptive
           timestep is more suited to the problem as this would be very hard
           to judge automatically otherwise.

        This integrator is about 1.5 - 2x as fast as RK4 and RK45 solvers (using
        similar Python implementations) in most setups.
        """
        # This error is below that needed to match the RK4 integrator. It
        # is set for visual reasons -- too low and corners start
        # appearing ugly and jagged. Can be tuned.
        maxerror = 0.003

        # This limit is important (for all integrators) to avoid the
        # trajectory skipping some mask cells. We could relax this
        # condition if we use the code which is commented out below to
        # increment the location gradually. However, due to the efficient
        # nature of the interpolation, this doesn't boost speed by much
        # for quite a bit of complexity.
        # Original matplotlib value, replaced by a fixed maximum step here:
        # maxds = min(1. / dmap.mask.nx, 1. / dmap.mask.ny, 0.1)
        maxds = 0.01

        ds = maxds
        stotal = 0
        xi = x0
        yi = y0
        xyf_traj = []

        while True:
            try:
                if dmap.grid.within_grid(xi, yi):
                    xyf_traj.append((xi, yi))
                else:
                    raise OutOfBounds

                # Compute the two intermediate gradients.
                # f should raise OutOfBounds if the locations given are
                # outside the grid.
                k1x, k1y = f(xi, yi)
                k2x, k2y = f(xi + ds * k1x, yi + ds * k1y)

            except OutOfBounds:
                # Out of the domain during this step.
                # Take an Euler step to the boundary to improve neatness
                # unless the trajectory is currently empty.
                if xyf_traj:
                    ds, xyf_traj = _euler_step(xyf_traj, dmap, f)
                    stotal += ds
                break
            except TerminateTrajectory:
                break

            dx1 = ds * k1x
            dy1 = ds * k1y
            dx2 = ds * 0.5 * (k1x + k2x)
            dy2 = ds * 0.5 * (k1y + k2y)

            ny, nx = dmap.grid.shape
            # Error is normalized to the axes coordinates
            error = np.hypot((dx2 - dx1) / (nx - 1), (dy2 - dy1) / (ny - 1))

            # Only save step if within error tolerance
            if error < maxerror:
                xi += dx2
                yi += dy2
                try:
                    dmap.update_trajectory(xi, yi, broken_streamlines)
                except InvalidIndexError:
                    break
                if stotal + ds > maxlength:
                    break
                stotal += ds

            # recalculate stepsize based on step error
            if error == 0:
                ds = maxds
            else:
                ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5)

        return stotal, xyf_traj


    def _euler_step(xyf_traj, dmap, f):
        """Simple Euler integration step that extends streamline to boundary."""
        ny, nx = dmap.grid.shape
        xi, yi = xyf_traj[-1]
        cx, cy = f(xi, yi)
        if cx == 0:
            dsx = np.inf
        elif cx < 0:
            dsx = xi / -cx
        else:
            dsx = (nx - 1 - xi) / cx
        if cy == 0:
            dsy = np.inf
        elif cy < 0:
            dsy = yi / -cy
        else:
            dsy = (ny - 1 - yi) / cy
        ds = min(dsx, dsy)
        xyf_traj.append((xi + cx * ds, yi + cy * ds))
        return ds, xyf_traj


    # Utility functions
    # ========================

    def interpgrid(a, xi, yi):
        """Fast 2D, linear interpolation on an integer grid"""

        Ny, Nx = np.shape(a)
        if isinstance(xi, np.ndarray):
            x = xi.astype(int)
            y = yi.astype(int)
            # Check that xn, yn don't exceed max index
            xn = np.clip(x + 1, 0, Nx - 1)
            yn = np.clip(y + 1, 0, Ny - 1)
        else:
            x = int(xi)
            y = int(yi)
            # conditional is faster than clipping for integers
            if x == (Nx - 1):
                xn = x
            else:
                xn = x + 1
            if y == (Ny - 1):
                yn = y
            else:
                yn = y + 1

        a00 = a[y, x]
        a01 = a[y, xn]
        a10 = a[yn, x]
        a11 = a[yn, xn]
        xt = xi - x
        yt = yi - y
        a0 = a00 * (1 - xt) + a01 * xt
        a1 = a10 * (1 - xt) + a11 * xt
        ai = a0 * (1 - yt) + a1 * yt

        if not isinstance(xi, np.ndarray):
            if np.ma.is_masked(ai):
                raise TerminateTrajectory

        return ai


    def _gen_starting_points(shape):
        """
        Yield starting points for streamlines.

        Trying points on the boundary first gives higher quality streamlines.
        This algorithm starts with a point on the mask corner and spirals inward.
        This algorithm is inefficient, but fast compared to rest of streamplot.
        """
        ny, nx = shape
        xfirst = 0
        yfirst = 1
        xlast = nx - 1
        ylast = ny - 1
        x, y = 0, 0
        direction = 'right'
        for i in range(nx * ny):
            yield x, y

            if direction == 'right':
                x += 1
                if x >= xlast:
                    xlast -= 1
                    direction = 'up'
            elif direction == 'up':
                y += 1
                if y >= ylast:
                    ylast -= 1
                    direction = 'left'
            elif direction == 'left':
                x -= 1
                if x <= xfirst:
                    xfirst += 1
                    direction = 'down'
            elif direction == 'down':
                y -= 1
                if y <= yfirst:
                    yfirst += 1
                    direction = 'right'
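
For readers who want the core idea of the adaptive second-order step without the mask bookkeeping, the following is a simplified standalone sketch, not the matplotlib implementation. It compares an Euler step with a Heun step, accepts the Heun step only if the two agree within a tolerance, and rescales the step with the same 0.85 * sqrt(tolerance / error) rule as the listing. The function name `heun_adaptive`, its parameters, and the rotation test field are assumptions made for illustration; the error is measured with a plain Euclidean norm rather than in axes coordinates.

```python
import numpy as np


def heun_adaptive(f, p0, s_max, max_ds=0.1, max_error=1e-3):
    """Integrate dp/ds = f(p) from p0 up to s_max with adaptive Heun steps."""
    p = np.asarray(p0, dtype=float)
    ds, s = max_ds, 0.0
    path = [p.copy()]
    while s_max - s > 1e-12:
        ds = min(ds, s_max - s)          # do not step past the end
        k1 = np.asarray(f(p))
        k2 = np.asarray(f(p + ds * k1))
        euler_step = ds * k1             # first-order estimate
        heun_step = ds * 0.5 * (k1 + k2)  # second-order estimate
        error = np.linalg.norm(heun_step - euler_step)

        # Only keep the step if both estimates agree within the tolerance.
        if error < max_error:
            p = p + heun_step
            s += ds
            path.append(p.copy())

        # Recalculate the step size from the error of the attempted step.
        if error == 0:
            ds = max_ds
        else:
            ds = min(max_ds, 0.85 * ds * (max_error / error) ** 0.5)
    return np.array(path)


def rotation(p):
    """Assumed test field: rigid rotation, so the exact path is a circle."""
    return np.array([-p[1], p[0]])


path = heun_adaptive(rotation, [1.0, 0.0], s_max=np.pi)
print(len(path), path[-1])
```

After half a revolution the end point should sit close to (-1, 0). Tightening `max_error` trades more steps for a smaller drift.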

huangapple
  • Posted on 2023-02-10 11:22:45
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75406627.html