获取从源节点到所有节点的最短距离优化

huangapple go评论78阅读模式
英文:

Get shortest distance from src to all nodes optimization

问题

以下是您要翻译的内容:

我有一个输入作为[][]edges。数组的列长度为2。因此,2D数组的每一行都有2个元素。每个元素都是一个顶点。它是双向的,即我们可以说边是双向的。因此,如果我们遍历这个2D数组,我们可以说我们有一个无向图。

我试图找到从一个特定节点到所有节点的最短距离。在这种情况下,从节点0到所有现有节点。

我有一个有效的代码,但我认为我在重新计算我想要避免的东西。我一次又一次地调用函数computeDistPerNode(m,0,key);,我确信它正在重新计算从0到之前调用中看到的节点的距离。我无法优化它并利用过去的计算。我该如何做?

以下是没有优化的工作代码

  1. public Map<Integer, List<Integer>> createUnDirectedGraph(int [][]edges) {
  2. Map<Integer, List<Integer>> m = new HashMap<>();
  3. for(var i = 0; i<edges.length; i++) {
  4. m.put(edges[i][0], new ArrayList<>());
  5. m.put(edges[i][1], new ArrayList<>());
  6. }
  7. for(var edge:edges) {
  8. var v1 = edge[0];
  9. var v2 = edge[1];
  10. m.get(v1).add(v2);
  11. m.get(v2).add(v1);
  12. }
  13. return m;
  14. }
  15. public int[] getShortestDistances(Map<Integer, List<Integer>> m) {
  16. int distance[] = new int[m.size()];
  17. for(Integer key:m.keySet()) {
  18. var d = computeDistPerNode(m,0,key);
  19. distance[key] = d;
  20. }
  21. return distance;
  22. }
  23. public int computeDistPerNode(Map<Integer, List<Integer>> m, int src, int dest) {
  24. Queue<Integer> q = new LinkedList<>();
  25. Integer dist[] = new Integer[m.size()];
  26. Set<Integer> visited = new HashSet<>();
  27. Arrays.fill(dist, Integer.MAX_VALUE);
  28. dist[src] = 0;
  29. q.add(src);
  30. while(!q.isEmpty()) {
  31. var currNode = q.poll();
  32. if(visited.contains(currNode)) continue;
  33. visited.add(currNode);
  34. if(currNode == dest) {
  35. return dist[dest];
  36. }
  37. for(var child: m.get(currNode)) {
  38. if (visited.contains(child)) {
  39. continue;
  40. }
  41. q.offer(child);
  42. var newDist = 1 + dist[currNode];
  43. if(newDist<dist[child]) {
  44. dist[child] = newDist;
  45. }
  46. }
  47. }
  48. return -1;
  49. }
  50. public int[][] getsample() {
  51. int [][] edges = {
  52. {0,1},
  53. {0,2},
  54. {1,4},
  55. {2,3},
  56. {4,3},
  57. {0,4},
  58. };
  59. return edges;
  60. }
英文:

I have an input as [][]edges. The col length of array is 2. Each row of the 2D array hence has 2 elements. Each element is a vertex. And it is bidirectional i.e we can say the edge is in both directions. Hence if we go through this 2D array, we can say we have an undirected graph.

I am trying to find the shortest distance from one particular node to all nodes. In this case say from node 0 to all the nodes that exist.

I have code that works but I think I am re-computing things which I want to avoid. I call the function computeDistPerNode(m,0,key); again and again and I am sure it is doing re-computation of distance from 0 to nodes that it has seen in prior calls. I am unable to optimize it and leverage the past computation. How do I do it?

Here is the working code without optimization

  1. public Map&lt;Integer, List&lt;Integer&gt;&gt; createUnDirectedGraph(int [][]edges) {
  2. Map&lt;Integer, List&lt;Integer&gt;&gt; m = new HashMap&lt;&gt;();
  3. for(var i = 0; i&lt;edges.length; i++) {
  4. m.put(edges[i][0], new ArrayList&lt;&gt;());
  5. m.put(edges[i][1], new ArrayList&lt;&gt;());
  6. }
  7. for(var edge:edges) {
  8. var v1 = edge[0];
  9. var v2 = edge[1];
  10. m.get(v1).add(v2);
  11. m.get(v2).add(v1);
  12. }
  13. return m;
  14. }
  15. public int[] getShortestDistances(Map&lt;Integer, List&lt;Integer&gt;&gt; m) {
  16. int distance[] = new int[m.size()];
  17. for(Integer key:m.keySet()) {
  18. var d = computeDistPerNode(m,0,key);
  19. distance[key] = d;
  20. }
  21. return distance;
  22. }
  23. public int computeDistPerNode(Map&lt;Integer, List&lt;Integer&gt;&gt; m, int src, int dest) {
  24. Queue&lt;Integer&gt; q = new LinkedList&lt;&gt;();
  25. Integer dist[] = new Integer[m.size()];
  26. Set&lt;Integer&gt; visited = new HashSet&lt;&gt;();
  27. Arrays.fill(dist, Integer.MAX_VALUE);
  28. dist[src] = 0;
  29. q.add(src);
  30. while(!q.isEmpty()) {
  31. var currNode = q.poll();
  32. if(visited.contains(currNode)) continue;
  33. visited.add(currNode);
  34. if(currNode == dest) {
  35. return dist[dest];
  36. }
  37. for(var child: m.get(currNode)) {
  38. if (visited.contains(child)) {
  39. continue;
  40. }
  41. q.offer(child);
  42. var newDist = 1 + dist[currNode];
  43. if(newDist&lt;dist[child]) {
  44. dist[child] = newDist;
  45. }
  46. }
  47. }
  48. return -1;
  49. }
  50. public int[][] getsample() {
  51. int [][] edges = {
  52. {0,1},
  53. {0,2},
  54. {1,4},
  55. {2,3},
  56. {4,3},
  57. {0,4},
  58. };
  59. return edges;
  60. }

答案1

得分: 3

  1. 你可以一次性计算从源节点到所有其他节点的距离。
  2. 方法 `int computeDistPerNode(Map&lt;Integer, List&lt;Integer&gt;&gt; m, int src, int dest)` 在到达目标节点时立即返回。将其改为在队列为空时返回 `dist` 数组。以下是修改后的方法:
  3. ```java
  4. public Integer[] computeDistFromSource(Map&lt;Integer, List&lt;Integer&gt;&gt; m, int src) {
  5. Set&lt;Integer&gt; visited = new HashSet&lt;&gt;();
  6. Integer[] dist = new Integer[m.size()];
  7. Arrays.fill(dist, Integer.MAX_VALUE);
  8. dist[src] = 0;
  9. Queue&lt;Integer&gt; q = new LinkedList&lt;&gt;();
  10. visited.add(src); // 在此处标记源节点为已访问
  11. q.add(src);
  12. while(!q.isEmpty()) {
  13. var currNode = q.poll();
  14. for(var child: m.get(currNode)) {
  15. if (!visited.contains(child)) {
  16. visited.add(child);
  17. q.offer(child);
  18. dist[child] = 1 + dist[currNode];
  19. }
  20. }
  21. }
  22. return dist;
  23. }

改进

如果稍微调整代码,可以避免三次 if 调用。这将导致代码更干净、更易读。

  1. public Integer[] computeDistFromSource(Map&lt;Integer, List&lt;Integer&gt;&gt; m, int src) {
  2. Set&lt;Integer&gt; visited = new HashSet&lt;&gt;();
  3. Integer[] dist = new Integer[m.size()];
  4. Arrays.fill(dist, Integer.MAX_VALUE);
  5. dist[src] = 0;
  6. Queue&lt;Integer&gt; q = new LinkedList&lt;&gt;();
  7. visited.add(src); // 在此处标记源节点为已访问
  8. q.add(src);
  9. while(!q.isEmpty()) {
  10. var currNode = q.poll();
  11. for(var child: m.get(currNode)) {
  12. if (!visited.contains(child)) {
  13. visited.add(child);
  14. q.offer(child);
  15. dist[child] = 1 + dist[currNode];
  16. }
  17. }
  18. }
  19. return dist;
  20. }

分析

所使用的算法是广度优先搜索。根据Wikipedia

> 时间复杂度可以表示为 O(|V| + |E|),因为在最坏情况下将探索每个顶点和每条边。|V| 是顶点数,|E| 是图中的边数。请注意,O(|E|) 可能在 O(1) 和 O(|V|^2) 之间变化,这取决于输入图的稀疏程度。

问题

> 你能帮我理解为什么如果不进行检查,可能会导致 newDist 的较大值不会写入当前的 dist[child] 吗?我认为原因是由于BFS/使用队列的性质,当一个未访问的节点被拉出时,子节点会首先被访问,因此不需要进行检查?

在你的代码中,if(newDist &lt; dist[child]) 条件是必要的,以确保代码的正确工作。在优化后的代码中,这不是必要的。原因在于 visited.add(child) 的位置。在你的代码中,该检查发生在从队列中获取节点之后。在优化后的代码中,这在发现节点后立即发生。这造成了很大的差异。

考虑你的输入图

  1. 0 ------- 1
  2. |\ |
  3. | \ |
  4. | \ |
  5. | \ |
  6. | \|
  7. | 4
  8. | |
  9. | |
  10. | |
  11. 2 ------- 3
你的代码的工作原理

源顶点是 0。在 while (!q.isEmpty()) 循环开始之前,我们将其添加到队列中。

while 循环中,我们移除 0 并将其标记为已访问。我们按顺序探索其邻居 1、2 和 4。我们将它们的距离更新为 1,并将它们全部添加到队列中。但是它们中没有一个被标记为已访问。

现在我们回到 while 循环的开始,获取 1 并将其标记为已访问。然后再次探索其邻居 0 和 4。我们不会更新 0 的距离,因为它已经被访问过了。我们再次将 4 添加到队列中,即使它已经是队列的一部分。 我们再次将相同的节点添加到队列中,这本身就不是一个好事情。请注意,如果没有 if(newDist &lt; dist[child]) 条件,它的距离将被错误地更新为 2。

优化后代码的工作原理

源顶点是 0。在 while (!q.isEmpty()) 循环开始之前,我们将其添加到队列中并在此处标记为已访问。

while 循环中,我们移除 0。我们按顺序探索其邻居 1、2 和 4。我们将它们的距离更新为 1,并将它们全部添加到队列中并标记为已访问。因此它们的距离永远不会再次被更新。

现在我们回到 while 循环的开始,获取 1 并再次探索其邻居 0 和 4。我们不会更新 0 和 4 的距离,因为它们都已经被访问过了。节点 4 也不会被添加到队列中两次。

  1. <details>
  2. <summary>英文:</summary>
  3. You can calculate distance from the source node to all the other nodes in one go.
  4. The method `int computeDistPerNode(Map&lt;Integer, List&lt;Integer&gt;&gt; m, int src, int dest)` returns as soon as you reach the destination node. Change that to return the `dist` array when the queue is empty. Here is your modified method

public Integer[] computeDistFromSource(Map<Integer, List<Integer>> m, int src) {
Set<Integer> visited = new HashSet<>();

  1. Integer[] dist = new Integer[m.size()];
  2. Arrays.fill(dist, Integer.MAX_VALUE);
  3. dist[src] = 0;
  4. Queue&lt;Integer&gt; q = new LinkedList&lt;&gt;();
  5. q.add(src);
  6. while(!q.isEmpty()) {
  7. var currNode = q.poll();
  8. if(visited.contains(currNode)) continue;
  9. visited.add(currNode);
  10. for(var child: m.get(currNode)) {
  11. if (visited.contains(child)) continue;
  12. q.offer(child);
  13. var newDist = 1 + dist[currNode];
  14. if(newDist &lt; dist[child]) {
  15. dist[child] = newDist;
  16. }
  17. }
  18. }
  19. return dist;

}

  1. ## Improvements
  2. If you re-position your lines a little, you can avoid three if calls. This results in a more clean and readable code.

public Integer[] computeDistFromSource(Map<Integer, List<Integer>> m, int src) {
Set<Integer> visited = new HashSet<>();

  1. Integer[] dist = new Integer[m.size()];
  2. Arrays.fill(dist, Integer.MAX_VALUE);
  3. dist[src] = 0;
  4. Queue&lt;Integer&gt; q = new LinkedList&lt;&gt;();
  5. visited.add(src); // mark source visited here
  6. q.add(src);
  7. while(!q.isEmpty()) {
  8. var currNode = q.poll();
  9. for(var child: m.get(currNode)) {
  10. if (!visited.contains(child)) {
  11. visited.add(child);
  12. q.offer(child);
  13. dist[child] = 1 + dist[currNode];
  14. }
  15. }
  16. }
  17. return dist;

}

  1. ## Analysis
  2. The algorithm employed is [Breadth-first search](https://en.wikipedia.org/wiki/Breadth-first_search). According to [Wikipedia](https://en.wikipedia.org/wiki/Breadth-first_search#Time_and_space_complexity)
  3. &gt; The time complexity can be expressed as `O(|V| + |E|)`, since every vertex and every edge will be explored in the worst case. `|V|` is the number of vertices and `|E|` is the number of edges in the graph. Note that `O(|E|)` may vary between `O(1)` and `O(|V|^2)`, depending on how sparse the input graph is.
  4. ## Question
  5. &gt; Can you help me understand how a larger value of newDist might not get written in current dist[child] without that check? I think the reason is that a child due to the nature of BFS/using queue will be visited first when an unvisited node is pulled out and hence the check is not required?
  6. The `if(newDist &lt; dist[child])` condition is necessary in your code for correct working. It is not required in the optimized code. The reason is the placement of `visited.add(child)`. In your code, that check happens after a node is polled from queue. In the optimized code, this happens immediately after a node is discovered. This creates a big difference.
  7. Consider your input graph

0 ------- 1
|\ |
| \ |
| \ |
| \ |
| \|
| 4
| |
| |
| |
2 ------- 3

  1. ##### Working of your code
  2. The source vertex is 0. Before the beginning of the loop `while (!q.isEmpty())` we add it to the queue.
  3. In the while loop, we remove 0 and mark it as visited. We explore its neighbors 1, 2 and 4 in that order. We update their distance to 1 and add all of them to the queue. *However, none of them have been marked as visited.*
  4. Now we go back to the start of the while loop, poll 1, mark it as visited and again explore its neighbors 0 and 4. We do not update the distance of 0 since it is visited. *We add 4 to the queue again even though it is already part of the queue.* Adding the same node to the queue twice is not a good thing in itself. *Notice that without the `if(newDist &lt; dist[child])` condition, its distance would be updated to 2, which is wrong.*
  5. ##### Working of the optimized code
  6. The source vertex is 0. Before the beginning of the loop `while (!q.isEmpty())` we add it to the queue and mark it as visited right there.
  7. In the while loop, we remove 0. We explore its neighbors 1, 2 and 4 in that order. We update their distance to 1 and add all of them to the queue and mark all of them as visited. *Hence their distance can never be updated again.*
  8. Now we go back to the start of the while loop, poll 1 and again explore its neighbors 0 and 4. We do not update the distance of either 0 or 4 since both of them are already visited. Node 4 is also not added to the queue a second time.
  9. </details>
  10. # 答案2
  11. **得分**: 1
  12. 如果您使用`min-priority-queue`或`min-heap`,您可以将算法复杂度降低到`O(|V| * |E|)`,即顶点数和边数的乘积。即使按照@AKSingh的[答案](https://stackoverflow.com/a/76298775/1202808)改进了您的算法,我认为它仍然是`O(|V|^2)`。
  13. 维基百科对[Dijkstra算法](https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm)有一个很好的描述,这是使用`min-priority-queue`解决最短路径问题的标准技术。[这里](https://takeuforward.org/data-structure/dijkstras-algorithm-using-priority-queue-g-32/)有一个更偏向教程的描述,其中包含很多图示,以可视化算法。
  14. 以下是实现该算法的一些示例代码。我很抱歉它不是用`Java`编写的,但翻译应该很简单。
  15. 示例代码
  16. -----------
  17. ```c++
  18. #include &lt;iostream&gt;
  19. #include &lt;map&gt;
  20. #include &lt;queue&gt;
  21. #include &lt;set&gt;
  22. #include &lt;vector&gt;
  23. using NodePair = std::pair&lt;int,int&gt;;
  24. using NodePairs = std::vector&lt;NodePair&gt;;
  25. using DistanceVertex = std::pair&lt;int, int&gt;;
  26. using MinQueue = std::priority_queue&lt;DistanceVertex,
  27. std::vector&lt;DistanceVertex&gt;,
  28. std::greater&lt;DistanceVertex&gt;&gt;;
  29. int main(int argc, const char *argv[]) {
  30. // 示例问题。我们将图存储为邻接列表
  31. // 使用multimap。
  32. std::multimap&lt;int, int&gt; edges {
  33. { 0, 1 },
  34. { 0, 2 },
  35. { 1, 4 },
  36. { 2, 3 },
  37. { 4, 3 },
  38. { 0, 4 }
  39. };
  40. // 有多少个顶点?
  41. int max_vertex{};
  42. for (auto [a, b] : edges) {
  43. max_vertex = std::max(max_vertex, a);
  44. max_vertex = std::max(max_vertex, b);
  45. }
  46. int number_vertices = max_vertex + 1;
  47. // 将源到每个顶点的距离初始化为MAX_INT。
  48. int source{};
  49. std::vector&lt;int&gt; distance(number_vertices, std::numeric_limits&lt;int&gt;::max());
  50. // 初始化到源的距离和优先队列
  51. MinQueue pq;
  52. distance[source] = 0;
  53. pq.emplace(0, source);
  54. while (!pq.empty()) {
  55. auto [udist, udx] = pq.top();
  56. pq.pop();
  57. // 遍历vdx的所有邻居
  58. auto [begin, end] = edges.equal_range(udx);
  59. for (auto iter = begin; iter != end; ++iter) {
  60. auto vdx = iter-&gt;second, vdist = iter-&gt;first;
  61. // 如果存在更短的路径,则记录它
  62. if (udist + vdist &lt; distance[vdx]) {
  63. distance[vdx] = udist + vdist;
  64. pq.push({udist, vdx});
  65. }
  66. }
  67. }
  68. // distance现在包含源和每个节点之间的最短距离
  69. for (auto i = 0; i &lt; number_vertices; ++i)
  70. std::cout &lt;&lt; distance[i] &lt;&lt; std::endl;
  71. return 0;
  72. }
英文:

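If you use a min-priority-queue or min-heap, you can reduce the algorithmic complexity to O(|V| * |E|), i.e. the product of the number of vertices and number of edges. Even with the improvements to your algorithm from @AKSingh's answer, I think it is still O(|V|^2). (Editor's note: these bounds look off — Dijkstra's algorithm with a binary heap is normally stated as O((|V| + |E|) log |V|), and the breadth-first search in the other answer is O(|V| + |E|), which is already optimal for an unweighted graph.)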

Wikipedia has a good description of Dijkstra's algorithm, which is the standard technique for solving the min-path problem with a min-priority-queue. Here is a more tutorial-oriented description with a lot of figures to visualize the algorithm.

The following is some sample code that implements the algorithm. I apologize that it is not in Java, but the translation should be straightforward.

Sample Code

  #include <algorithm>
  #include <iostream>
  #include <limits>
  #include <map>
  #include <queue>
  #include <set>
  #include <utility>
  #include <vector>

  using NodePair = std::pair<int, int>;
  using NodePairs = std::vector<NodePair>;
  // (distance from source, vertex) -- ordered by distance first so the
  // priority queue always yields the closest pending vertex.
  using DistanceVertex = std::pair<int, int>;
  using MinQueue = std::priority_queue<DistanceVertex,
                                       std::vector<DistanceVertex>,
                                       std::greater<DistanceVertex>>;

  /// Runs Dijkstra's algorithm on the sample undirected, unweighted graph and
  /// prints the shortest distance from vertex 0 to every vertex, one per line.
  /// Vertices that cannot be reached keep the sentinel
  /// std::numeric_limits<int>::max().
  int main(int /*argc*/, const char * /*argv*/[]) {
    // The sample problem, given as a list of undirected edges.
    const NodePairs input_edges{{0, 1}, {0, 2}, {1, 4}, {2, 3}, {4, 3}, {0, 4}};

    // Every edge of the unweighted graph costs one step. (The multimap key is
    // the vertex id, not a weight, so it must not be used as the edge cost.)
    constexpr int kEdgeWeight = 1;

    // Build the adjacency list. Each edge is stored in BOTH directions because
    // the graph is undirected; equal_range(u) then yields all neighbors of u.
    std::multimap<int, int> edges;
    int max_vertex{};
    for (const auto& [a, b] : input_edges) {
      edges.emplace(a, b);
      edges.emplace(b, a);
      max_vertex = std::max({max_vertex, a, b});
    }
    const int number_vertices = max_vertex + 1;

    // Initialize the distance from source to each vertex as "infinity".
    const int source{};
    std::vector<int> distance(number_vertices, std::numeric_limits<int>::max());

    // Initialize distance to source and the priority queue.
    MinQueue pq;
    distance[source] = 0;
    pq.emplace(0, source);

    while (!pq.empty()) {
      const auto [udist, udx] = pq.top();
      pq.pop();

      // A vertex may be queued several times; skip entries that were
      // superseded by a shorter path found later.
      if (udist > distance[udx]) continue;

      // Iterate over all neighbors of udx.
      const auto [begin, end] = edges.equal_range(udx);
      for (auto iter = begin; iter != end; ++iter) {
        const int vdx = iter->second;
        const int candidate = udist + kEdgeWeight;
        // If there is a shorter path, record it and queue the neighbor with
        // its NEW distance (not the parent's).
        if (candidate < distance[vdx]) {
          distance[vdx] = candidate;
          pq.emplace(candidate, vdx);
        }
      }
    }

    // distance now contains the shortest distance between source and each node.
    for (const int d : distance)
      std::cout << d << '\n';
    return 0;
  }

huangapple
  • 本文由 发表于 2023年5月21日 10:48:43
  • 转载请务必保留本文链接:https://go.coder-hub.com/76298093.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定