Trying to implement Dijkstra in Java using a priority queue and a hash map for decrease-key

Trying to implement Dijkstra in Java using a priority queue and a hash map for decrease-key


  1. // this has a bug which I dont know how to find
  2. public int networkDelayTime(int[][] times, int N, int K) {
  3. Map<Integer, Map<Integer,Integer>> adjListWithDistance = new HashMap<>();
  4. // using this distance Map from K as comparator in priority queue
  5. Map<Integer, Integer> dMapFromK = new HashMap<>();
  6. PriorityQueue<Integer> pq = new PriorityQueue<>((k1,k2) -> dMapFromK.get(k1) - dMapFromK.get(k2));
  7. HashSet<Integer> v = new HashSet<>();
  8. for(int i = 0; i < times.length; i++){
  9. int source = times[i][0];
  10. int dest = times[i][1];
  11. int dist = times[i][2];
  12. adjListWithDistance.putIfAbsent(source, new HashMap<>());
  13. adjListWithDistance.get(source).put(dest, dist);
  14. }
  15. dMapFromK.put(K, 0);
  16. pq.add(K);
  17. int res = 0;
  18. while(!pq.isEmpty()){
  19. int fromNode = pq.poll();
  20. if(v.contains(fromNode)) continue;
  21. v.add(fromNode);
  22. int curDist = dMapFromK.get(fromNode);
  23. res = curDist;
  24. if(!adjListWithDistance.containsKey(fromNode)) {
  25. continue;
  26. }
  27. for(Integer toNode: adjListWithDistance.get(fromNode).keySet()){
  28. if(v.contains(toNode)) continue;
  29. int toNodeDist = adjListWithDistance.get(fromNode).get(toNode);
  30. if(dMapFromK.containsKey(toNode) && dMapFromK.get(toNode) <= curDist + toNodeDist){
  31. continue;
  32. } else {
  33. if(!dMapFromK.containsKey(toNode)){
  34. dMapFromK.put(toNode, curDist + toNodeDist);
  35. pq.offer(toNode);
  36. } else{
  37. dMapFromK.put(toNode, curDist + toNodeDist);
  38. }
  39. }
  40. }
  41. }
  42. if(dMapFromK.keySet().size() != N)
  43. return -1;
  44. return Collections.max(dMapFromK.values());
  45. }

Full disclosure: I am doing this exercise to solve a LeetCode problem (743, "Network Delay Time").

I find that this code does not work for certain test cases. I have been trying to debug it for a few hours without any luck. Can anyone help me find the bug in this code?

  1. // this has a bug which I dont know how to find
  2. public int networkDelayTime(int[][] times, int N, int K) {
  3. Map&lt;Integer, Map&lt;Integer,Integer&gt;&gt; adjListWithDistance = new HashMap&lt;&gt;();
  4. // using this distance Map from K as comparator in priority queue
  5. Map&lt;Integer, Integer&gt; dMapFromK = new HashMap&lt;&gt;();
  6. PriorityQueue&lt;Integer&gt; pq = new PriorityQueue&lt;&gt;((k1,k2) -&gt; dMapFromK.get(k1) - dMapFromK.get(k2));
  7. HashSet&lt;Integer&gt; v = new HashSet&lt;&gt;();
  8. for(int i = 0; i &lt; times.length; i++){
  9. int source = times[i][0];
  10. int dest = times[i][1];
  11. int dist = times[i][2];
  12. adjListWithDistance.putIfAbsent(source, new HashMap&lt;&gt;());
  13. adjListWithDistance.get(source).put(dest, dist);
  14. //if(source == K){
  15. // dMapFromK.put(dest, dist);
  16. // pq.add(dest);
  17. // }
  18. }
  19. // distance from K to K is 0
  20. dMapFromK.put(K, 0);
  21. pq.add(K);
  22. // we have already added all nodes from K to PQ, so we dont need to process K again
  23. //v.add(K);
  24. //System.out.println(adjListWithDistance);
  25. //System.out.println(dMapFromK);
  26. int res = 0;
  27. while(!pq.isEmpty()){
  28. int fromNode = pq.poll();
  29. if(v.contains(fromNode)) continue;
  30. v.add(fromNode);
  31. int curDist = dMapFromK.get(fromNode);
  32. res = curDist;
  33. //System.out.println(&quot;current node - &quot; + fromNode);
  34. if(!adjListWithDistance.containsKey(fromNode)) {
  35. continue;
  36. }
  37. for(Integer toNode: adjListWithDistance.get(fromNode).keySet()){
  38. // BIG BUGGGGG, adding the below line is also causing a bug , not sure why
  39. if(v.contains(toNode)) continue;
  40. int toNodeDist = adjListWithDistance.get(fromNode).get(toNode);
  41. if(dMapFromK.containsKey(toNode) &amp;&amp; dMapFromK.get(toNode) &lt;= curDist + toNodeDist){
  42. continue;
  43. }else {
  44. if(!dMapFromK.containsKey(toNode)){
  45. dMapFromK.put(toNode, curDist + toNodeDist);
  46. // need to add map entry first before adding to priority queue else it throws an exception
  47. pq.offer(toNode);
  48. }else{
  49. dMapFromK.put(toNode, curDist + toNodeDist);
  50. }
  51. }
  52. }
  53. }
  54. System.out.println(adjListWithDistance);
  55. System.out.println(dMapFromK);
  56. if(dMapFromK.keySet().size() != N)
  57. return -1;
  58. //return res;
  59. return Collections.max(dMapFromK.values());
  60. }

TL;DR: This implementation of Dijkstra is not correct and doesn't return the shortest path for some nodes in certain test cases. I'm not sure why, and I need help finding the mistake I am making.


得分: 1



  1. public final class Solution {
  2. public static final int networkDelayTime(
  3. final int[][] times,
  4. int n,
  5. final int k
  6. ) {
  7. Map<Integer, Map<Integer, Integer>> graph = new HashMap<>();
  8. for (final int[] node : times) {
  9. graph.putIfAbsent(node[0], new HashMap<>());
  10. graph.get(node[0]).put(node[1], node[2]);
  11. }
  12. Queue<int[]> queue = new PriorityQueue<>((a, b) -> (a[0] - b[0]));
  13. queue.add(new int[] {0, k});
  14. boolean[] visited = new boolean[n + 1];
  15. int total = 0;
  16. while (!queue.isEmpty()) {
  17. int[] curr = queue.remove();
  18. int currNode = curr[1];
  19. int currTime = curr[0];
  20. if (visited[currNode]) {
  21. continue;
  22. }
  23. visited[currNode] = true;
  24. total = currTime;
  25. n--;
  26. if (graph.containsKey(currNode)) {
  27. for (final int next : graph.get(currNode).keySet()) {
  28. queue.add(new int[] {currTime + graph.get(currNode).get(next), next});
  29. }
  30. }
  31. }
  32. return n == 0 ? total : -1;
  33. }
  34. }

以下是使用堆的 Python 版本,如果你有兴趣的话:

  1. from typing import List
  2. import heapq
  3. from collections import defaultdict
  4. class Solution:
  5. def networkDelayTime(self, times: List[List[int]], n, k) -> int:
  6. queue = [(0, k)]
  7. graph = collections.defaultdict(list)
  8. memo = {}
  9. for u_node, v_node, time in times:
  10. graph[u_node].append((v_node, time))
  11. while queue:
  12. time, node = heapq.heappop(queue)
  13. if node not in memo:
  14. memo[node] = time
  15. for v_node, v_time in graph[node]:
  16. heapq.heappush(queue, (time + v_time, v_node))
  17. return max(memo.values()) if len(memo) == n else -1

在 C++ 中,我们将使用快速整数类型:

  1. // 以下部分可能会微不足道地提高执行时间;
  2. // 可以删除;
  3. static const auto __optimize__ = []() {
  4. std::ios::sync_with_stdio(false);
  5. std::cin.tie(NULL);
  6. std::cout.tie(NULL);
  7. return 0;
  8. }();
  9. // 大多数头文件已经包含;
  10. // 可以删除;
  11. #include <cstdint>
  12. #include <vector>
  13. #include <algorithm>
  14. #define MAX INT_MAX
  15. using ValueType = std::uint_fast16_t;
  16. static const struct Solution {
  17. static const int networkDelayTime(
  18. const std::vector<vector<int>>& times,
  19. int n,
  20. const int k
  21. ) {
  22. std::vector<ValueType> distances(n + 1, MAX);
  23. distances[k] = 0;
  24. for (ValueType index = 0; index < n; index++) {
  25. for (const auto& time : times) {
  26. const ValueType u_node = time[0];
  27. const ValueType v_node = time[1];
  28. const ValueType uv_weight = time[2];
  29. if (distances[u_node] != MAX && distances[v_node] > distances[u_node] + uv_weight) {
  30. distances[v_node] = distances[u_node] + uv_weight;
  31. }
  32. }
  33. }
  34. ValueType total_time = 0;
  35. for (auto index = 1; index <= n; index++) {
  36. total_time = std::max(total_time, distances[index]);
  37. }
  38. return total_time == MAX ? -1 : total_time;
  39. }
  40. };

This is almost the same and will pass:

  1. public final class Solution {
  2. public static final int networkDelayTime(
  3. final int[][] times,
  4. int n,
  5. final int k
  6. ) {
  7. Map&lt;Integer, Map&lt;Integer, Integer&gt;&gt; graph = new HashMap&lt;&gt;();
  8. for (final int[] node : times) {
  9. graph.putIfAbsent(node[0], new HashMap&lt;&gt;());
  10. graph.get(node[0]).put(node[1], node[2]);
  11. }
  12. Queue&lt;int[]&gt; queue = new PriorityQueue&lt;&gt;((a, b) -&gt; (a[0] - b[0]));
  13. queue.add(new int[] {0, k});
  14. boolean[] visited = new boolean[n + 1];
  15. int total = 0;
  16. while (!queue.isEmpty()) {
  17. int[] curr = queue.remove();
  18. int currNode = curr[1];
  19. int currTime = curr[0];
  20. if (visited[currNode]) {
  21. continue;
  22. }
  23. visited[currNode] = true;
  24. total = currTime;
  25. n--;
  26. if (graph.containsKey(currNode)) {
  27. for (final int next : graph.get(currNode).keySet()) {
  28. queue.add(new int[] {currTime + graph.get(currNode).get(next), next});
  29. }
  30. }
  31. }
  32. return n == 0 ? total : -1;
  33. }
  34. }

Here is a Python version using a heap, if you're interested:

  1. from typing import List
  2. import heapq
  3. from collections import defaultdict
  4. class Solution:
  5. def networkDelayTime(self, times: List[List[int]], n, k) -&gt; int:
  6. queue = [(0, k)]
  7. graph = collections.defaultdict(list)
  8. memo = {}
  9. for u_node, v_node, time in times:
  10. graph[u_node].append((v_node, time))
  11. while queue:
  12. time, node = heapq.heappop(queue)
  13. if node not in memo:
  14. memo[node] = time
  15. for v_node, v_time in graph[node]:
  16. heapq.heappush(queue, (time + v_time, v_node))
  17. return max(memo.values()) if len(memo) == n else -1

In C++, we'd just use a fast integer type (note that this version uses Bellman-Ford relaxation rather than a heap):

  1. // The following block might trivially improve the exec time;
  2. // Can be removed;
  3. static const auto __optimize__ = []() {
  4. std::ios::sync_with_stdio(false);
  5. std::cin.tie(NULL);
  6. std::cout.tie(NULL);
  7. return 0;
  8. }();
  9. // Most of headers are already included;
  10. // Can be removed;
  11. #include &lt;cstdint&gt;
  12. #include &lt;vector&gt;
  13. #include &lt;algorithm&gt;
  14. #define MAX INT_MAX
  15. using ValueType = std::uint_fast16_t;
  16. static const struct Solution {
  17. static const int networkDelayTime(
  18. const std::vector&lt;vector&lt;int&gt;&gt;&amp; times,
  19. int n,
  20. const int k
  21. ) {
  22. std::vector&lt;ValueType&gt; distances(n + 1, MAX);
  23. distances[k] = 0;
  24. for (ValueType index = 0; index &lt; n; index++) {
  25. for (const auto&amp; time : times) {
  26. const ValueType u_node = time[0];
  27. const ValueType v_node = time[1];
  28. const ValueType uv_weight = time[2];
  29. if (distances[u_node] != MAX &amp;&amp; distances[v_node] &gt; distances[u_node] + uv_weight) {
  30. distances[v_node] = distances[u_node] + uv_weight;
  31. }
  32. }
  33. }
  34. ValueType total_time = 0;
  35. for (auto index = 1; index &lt;= n; index++) {
  36. total_time = std::max(total_time, distances[index]);
  37. }
  38. return total_time == MAX ? -1 : total_time;
  39. }
  40. };


  • For additional details, see the LeetCode Discussion Board, which has plenty of well-explained accepted solutions in a variety of languages, including low-complexity algorithms with asymptotic runtime/memory analysis [1, 2].

