Trying to implement Dijkstra's algorithm in Java using a priority queue and a hash map for decrease-key

huangapple go评论93阅读模式
英文:

Trying to implement Dijkstra's algorithm in Java using a priority queue and a hash map for decrease-key

问题

  1. // Question code, kept as posted (it is known to be buggy). Intended to return how long a
  // signal sent from node K takes to reach all N nodes over the directed, weighted edges in
  // times ({source, dest, dist} per row), or -1 when fewer than N nodes received a distance.
  2. public int networkDelayTime(int[][] times, int N, int K) {
  3. Map<Integer, Map<Integer,Integer>> adjListWithDistance = new HashMap<>();
  4. // best known distance from K per node; the priority queue's comparator reads this map
  5. Map<Integer, Integer> dMapFromK = new HashMap<>();
  6. PriorityQueue<Integer> pq = new PriorityQueue<>((k1,k2) -> dMapFromK.get(k1) - dMapFromK.get(k2));
  // nodes that have already been polled and processed
  7. HashSet<Integer> v = new HashSet<>();
  // build the adjacency map: source -> (dest -> dist)
  8. for(int i = 0; i < times.length; i++){
  9. int source = times[i][0];
  10. int dest = times[i][1];
  11. int dist = times[i][2];
  12. adjListWithDistance.putIfAbsent(source, new HashMap<>());
  13. adjListWithDistance.get(source).put(dest, dist);
  14. }
  // the map entry must exist before pq.add, because the comparator looks the node up in the map
  15. dMapFromK.put(K, 0);
  16. pq.add(K);
  // res tracks the distance of the most recently processed node; it is never returned
  17. int res = 0;
  18. while(!pq.isEmpty()){
  19. int fromNode = pq.poll();
  20. if(v.contains(fromNode)) continue;
  21. v.add(fromNode);
  22. int curDist = dMapFromK.get(fromNode);
  23. res = curDist;
  // nodes without outgoing edges have no entry in the adjacency map
  24. if(!adjListWithDistance.containsKey(fromNode)) {
  25. continue;
  26. }
  27. for(Integer toNode: adjListWithDistance.get(fromNode).keySet()){
  28. if(v.contains(toNode)) continue;
  29. int toNodeDist = adjListWithDistance.get(fromNode).get(toNode);
  // skip when the recorded distance is already at least as good as going through fromNode
  30. if(dMapFromK.containsKey(toNode) && dMapFromK.get(toNode) <= curDist + toNodeDist){
  31. continue;
  32. } else {
  33. if(!dMapFromK.containsKey(toNode)){
  34. dMapFromK.put(toNode, curDist + toNodeDist);
  35. pq.offer(toNode);
  36. } else{
  // BUG: toNode is still inside pq here. Lowering its distance in the map changes the value
  // the comparator sees, but the queue is never told, so it keeps its old internal order and
  // can later poll nodes out of distance order. The node has to be re-inserted (or a fresh
  // {distance, node} entry offered) for the new distance to take effect.
  37. dMapFromK.put(toNode, curDist + toNodeDist);
  38. }
  39. }
  40. }
  41. }
  // fewer than N recorded distances means some node was never reached
  42. if(dMapFromK.keySet().size() != N)
  43. return -1;
  44. return Collections.max(dMapFromK.values());
  45. }
英文:

Full Disclosure - I am doing this exercise to solve this problem on Leetcode - https://leetcode.com/problems/network-delay-time/

I find that this code does not work for certain test cases. I have been trying to debug it for a few hours and haven't had any luck. Can anyone help me find the bug in this code?

  1. // Question code, kept as posted (it is known to be buggy). Intended to return how long a
  // signal sent from node K takes to reach all N nodes over the directed, weighted edges in
  // times ({source, dest, dist} per row), or -1 when fewer than N nodes received a distance.
  2. public int networkDelayTime(int[][] times, int N, int K) {
  3. Map&lt;Integer, Map&lt;Integer,Integer&gt;&gt; adjListWithDistance = new HashMap&lt;&gt;();
  4. // best known distance from K per node; the priority queue's comparator reads this map
  5. Map&lt;Integer, Integer&gt; dMapFromK = new HashMap&lt;&gt;();
  6. PriorityQueue&lt;Integer&gt; pq = new PriorityQueue&lt;&gt;((k1,k2) -&gt; dMapFromK.get(k1) - dMapFromK.get(k2));
  // nodes that have already been polled and processed
  7. HashSet&lt;Integer&gt; v = new HashSet&lt;&gt;();
  // build the adjacency map: source -> (dest -> dist)
  8. for(int i = 0; i &lt; times.length; i++){
  9. int source = times[i][0];
  10. int dest = times[i][1];
  11. int dist = times[i][2];
  12. adjListWithDistance.putIfAbsent(source, new HashMap&lt;&gt;());
  13. adjListWithDistance.get(source).put(dest, dist);
  14. //if(source == K){
  15. // dMapFromK.put(dest, dist);
  16. // pq.add(dest);
  17. // }
  18. }
  19. // distance from K to K is 0
  20. dMapFromK.put(K, 0);
  21. pq.add(K);
  22. // we have already added all nodes from K to PQ, so we dont need to process K again
  23. //v.add(K);
  24. //System.out.println(adjListWithDistance);
  25. //System.out.println(dMapFromK);
  // res tracks the distance of the most recently processed node; it is never returned
  26. int res = 0;
  27. while(!pq.isEmpty()){
  28. int fromNode = pq.poll();
  29. if(v.contains(fromNode)) continue;
  30. v.add(fromNode);
  31. int curDist = dMapFromK.get(fromNode);
  32. res = curDist;
  33. //System.out.println(&quot;current node - &quot; + fromNode);
  // nodes without outgoing edges have no entry in the adjacency map
  34. if(!adjListWithDistance.containsKey(fromNode)) {
  35. continue;
  36. }
  37. for(Integer toNode: adjListWithDistance.get(fromNode).keySet()){
  38. // BIG BUGGGGG, adding the below line is also causing a bug , not sure why
  39. if(v.contains(toNode)) continue;
  40. int toNodeDist = adjListWithDistance.get(fromNode).get(toNode);
  // skip when the recorded distance is already at least as good as going through fromNode
  41. if(dMapFromK.containsKey(toNode) &amp;&amp; dMapFromK.get(toNode) &lt;= curDist + toNodeDist){
  42. continue;
  43. }else {
  44. if(!dMapFromK.containsKey(toNode)){
  45. dMapFromK.put(toNode, curDist + toNodeDist);
  46. // need to add map entry first before adding to priority queue else it throws an exception
  47. pq.offer(toNode);
  48. }else{
  // BUG: toNode is still inside pq here. Lowering its distance in the map changes the value
  // the comparator sees, but the queue is never told, so it keeps its old internal order and
  // can later poll nodes out of distance order. The node has to be re-inserted (or a fresh
  // {distance, node} entry offered) for the new distance to take effect.
  49. dMapFromK.put(toNode, curDist + toNodeDist);
  50. }
  51. }
  52. }
  53. }
  // debug output left in by the author
  54. System.out.println(adjListWithDistance);
  55. System.out.println(dMapFromK);
  // fewer than N recorded distances means some node was never reached
  56. if(dMapFromK.keySet().size() != N)
  57. return -1;
  58. //return res;
  59. return Collections.max(dMapFromK.values());
  60. }

TL;DR: This implementation of Dijkstra's algorithm is not correct and doesn't return the shortest path to some nodes in certain test cases. I'm not sure why, and I need help finding the mistake I am making.

答案1

得分: 1

以下是翻译后的内容:

这几乎是相同的,将会通过:

  1. public final class Solution {
  2. public static final int networkDelayTime(
  3. final int[][] times,
  4. int n,
  5. final int k
  6. ) {
  7. Map<Integer, Map<Integer, Integer>> graph = new HashMap<>();
  8. for (final int[] node : times) {
  9. graph.putIfAbsent(node[0], new HashMap<>());
  10. graph.get(node[0]).put(node[1], node[2]);
  11. }
  12. Queue<int[]> queue = new PriorityQueue<>((a, b) -> (a[0] - b[0]));
  13. queue.add(new int[] {0, k});
  14. boolean[] visited = new boolean[n + 1];
  15. int total = 0;
  16. while (!queue.isEmpty()) {
  17. int[] curr = queue.remove();
  18. int currNode = curr[1];
  19. int currTime = curr[0];
  20. if (visited[currNode]) {
  21. continue;
  22. }
  23. visited[currNode] = true;
  24. total = currTime;
  25. n--;
  26. if (graph.containsKey(currNode)) {
  27. for (final int next : graph.get(currNode).keySet()) {
  28. queue.add(new int[] {currTime + graph.get(currNode).get(next), next});
  29. }
  30. }
  31. }
  32. return n == 0 ? total : -1;
  33. }
  34. }

以下是使用堆的 Python 版本,如果你有兴趣的话:

  1. from typing import List
  2. import heapq
  3. from collections import defaultdict
  4. class Solution:
  5. def networkDelayTime(self, times: List[List[int]], n, k) -> int:
  6. queue = [(0, k)]
  7. graph = collections.defaultdict(list)
  8. memo = {}
  9. for u_node, v_node, time in times:
  10. graph[u_node].append((v_node, time))
  11. while queue:
  12. time, node = heapq.heappop(queue)
  13. if node not in memo:
  14. memo[node] = time
  15. for v_node, v_time in graph[node]:
  16. heapq.heappush(queue, (time + v_time, v_node))
  17. return max(memo.values()) if len(memo) == n else -1

在 C++ 中,我们将使用快速整数类型:

  1. // 以下部分可能会微不足道地提高执行时间;
  2. // 可以删除;
  3. static const auto __optimize__ = []() {
  4. std::ios::sync_with_stdio(false);
  5. std::cin.tie(NULL);
  6. std::cout.tie(NULL);
  7. return 0;
  8. }();
  9. // 大多数头文件已经包含;
  10. // 可以删除;
  11. #include <cstdint>
  12. #include <vector>
  13. #include <algorithm>
  14. #define MAX INT_MAX
  15. using ValueType = std::uint_fast16_t;
  16. static const struct Solution {
  17. static const int networkDelayTime(
  18. const std::vector<vector<int>>& times,
  19. int n,
  20. const int k
  21. ) {
  22. std::vector<ValueType> distances(n + 1, MAX);
  23. distances[k] = 0;
  24. for (ValueType index = 0; index < n; index++) {
  25. for (const auto& time : times) {
  26. const ValueType u_node = time[0];
  27. const ValueType v_node = time[1];
  28. const ValueType uv_weight = time[2];
  29. if (distances[u_node] != MAX && distances[v_node] > distances[u_node] + uv_weight) {
  30. distances[v_node] = distances[u_node] + uv_weight;
  31. }
  32. }
  33. }
  34. ValueType total_time = 0;
  35. for (auto index = 1; index <= n; index++) {
  36. total_time = std::max(total_time, distances[index]);
  37. }
  38. return total_time == MAX ? -1 : total_time;
  39. }
  40. };
英文:

This is almost the same as your approach, and it will pass:

  1. public final class Solution {
  2. public static final int networkDelayTime(
  3. final int[][] times,
  4. int n,
  5. final int k
  6. ) {
  7. Map&lt;Integer, Map&lt;Integer, Integer&gt;&gt; graph = new HashMap&lt;&gt;();
  8. for (final int[] node : times) {
  9. graph.putIfAbsent(node[0], new HashMap&lt;&gt;());
  10. graph.get(node[0]).put(node[1], node[2]);
  11. }
  12. Queue&lt;int[]&gt; queue = new PriorityQueue&lt;&gt;((a, b) -&gt; (a[0] - b[0]));
  13. queue.add(new int[] {0, k});
  14. boolean[] visited = new boolean[n + 1];
  15. int total = 0;
  16. while (!queue.isEmpty()) {
  17. int[] curr = queue.remove();
  18. int currNode = curr[1];
  19. int currTime = curr[0];
  20. if (visited[currNode]) {
  21. continue;
  22. }
  23. visited[currNode] = true;
  24. total = currTime;
  25. n--;
  26. if (graph.containsKey(currNode)) {
  27. for (final int next : graph.get(currNode).keySet()) {
  28. queue.add(new int[] {currTime + graph.get(currNode).get(next), next});
  29. }
  30. }
  31. }
  32. return n == 0 ? total : -1;
  33. }
  34. }

Here is a Python version using a heap, if you're interested:

  1. from typing import List
  2. import heapq
  3. from collections import defaultdict
  4. class Solution:
  5. def networkDelayTime(self, times: List[List[int]], n, k) -&gt; int:
  6. queue = [(0, k)]
  7. graph = collections.defaultdict(list)
  8. memo = {}
  9. for u_node, v_node, time in times:
  10. graph[u_node].append((v_node, time))
  11. while queue:
  12. time, node = heapq.heappop(queue)
  13. if node not in memo:
  14. memo[node] = time
  15. for v_node, v_time in graph[node]:
  16. heapq.heappush(queue, (time + v_time, v_node))
  17. return max(memo.values()) if len(memo) == n else -1

In C++, we'd just use a fast integer type:

  1. // The following block might trivially improve the exec time;
  2. // Can be removed;
  3. static const auto __optimize__ = []() {
  4. std::ios::sync_with_stdio(false);
  5. std::cin.tie(NULL);
  6. std::cout.tie(NULL);
  7. return 0;
  8. }();
  9. // Most of headers are already included;
  10. // Can be removed;
  11. #include &lt;cstdint&gt;
  12. #include &lt;vector&gt;
  13. #include &lt;algorithm&gt;
  14. #define MAX INT_MAX
  15. using ValueType = std::uint_fast16_t;
  16. static const struct Solution {
  17. static const int networkDelayTime(
  18. const std::vector&lt;vector&lt;int&gt;&gt;&amp; times,
  19. int n,
  20. const int k
  21. ) {
  22. std::vector&lt;ValueType&gt; distances(n + 1, MAX);
  23. distances[k] = 0;
  24. for (ValueType index = 0; index &lt; n; index++) {
  25. for (const auto&amp; time : times) {
  26. const ValueType u_node = time[0];
  27. const ValueType v_node = time[1];
  28. const ValueType uv_weight = time[2];
  29. if (distances[u_node] != MAX &amp;&amp; distances[v_node] &gt; distances[u_node] + uv_weight) {
  30. distances[v_node] = distances[u_node] + uv_weight;
  31. }
  32. }
  33. }
  34. ValueType total_time = 0;
  35. for (auto index = 1; index &lt;= n; index++) {
  36. total_time = std::max(total_time, distances[index]);
  37. }
  38. return total_time == MAX ? -1 : total_time;
  39. }
  40. };

References

  • For additional details, please see the Discussion Board where you can find plenty of well-explained accepted solutions with a variety of languages including low-complexity algorithms and asymptotic runtime/memory analysis<sup>1, 2</sup>.

huangapple
  • 本文由 发表于 2020年8月15日 07:22:41
  • 转载请务必保留本文链接:https://go.coder-hub.com/63421093.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定