如何让我的BFS算法运行更快?

huangapple go评论82阅读模式
英文:

How can I make my BFS algorithm run faster?

问题

以下是您提供的代码部分的翻译:

public void BFS(Point s, Point e) {
    /**
     * 北、南、东、西方向的坐标
     */
    int[] x = {0, 0, 1, -1};
    int[] y = {1, -1, 0, 0};

    LinkedList<Point> queue = new LinkedList<>();
    queue.add(s);

    /**
     * 存储每个点距离起点的距离的二维 int[][] 网格
     */
    int[][] dist = new int[numRow][numCol];
    for (int[] a : dist) {
        Arrays.fill(a, -1);
    }
    /**
     * "obstacles" 是一个包含网格上障碍物(坐标为 (x, y))的 Point 数组,
     * 标记为 -2,BFS 算法将避开这些障碍物。
     */
    for (Point ob : obstacles) {
        dist[ob.x][ob.y] = -2;
    }
    // 起点
    dist[s.x][s.y] = 0;

    /**
     * 从起点开始遍历 dist[][],将每个 [x][y] 坐标更改为距离 S 的 int 值。
     */
    while (!queue.isEmpty()) {
        Point p = queue.removeFirst();
        for (int i = 0; i < 4; i++) {
            int a = p.x + x[i];
            int b = p.y + y[i];
            if (a >= 0 && b >= 0 && a < numRow && b < numCol && dist[a][b] == -1) {
                dist[a][b] = 1 + dist[p.x][p.y];
                Point tempPoint = new Point(a, b);
                if (!queue.contains(tempPoint)) {
                    queue.add(tempPoint);
                }
            }
        }
    }

    /**
     * 反向查找所有 S 和 E 之间的最短路径点,并将每个点添加到名为 "validPaths" 的数组中。
     */
    queue.add(e);
    while (!queue.isEmpty()) {
        Point p = queue.removeFirst();

        // 检查点 p 上方、下方、左侧和右侧的网格空间
        for (int i = 0; i < 4; i++) {
            int curX = p.x + x[i];
            int curY = p.y + y[i];

            // 索引越界检查
            if (curX >= 0 && curY >= 0 && !(curX == start.x && curY == start.y) &&
                curX < numRow && curY < numCol) {
                if (dist[curX][curY] < dist[p.x][p.y] && dist[curX][curY] != -2) { // -2 代表障碍物
                    Point tempPoint = new Point(curX, curY);
                    if (!validPaths.contains(tempPoint)) {
                        validPaths.add(tempPoint);
                    }
                    if (!queue.contains(tempPoint)) {
                        queue.add(tempPoint);
                    }
                }
            }
        }
    }
}

请注意,这只是您提供的代码部分的翻译,不包括额外的解释或附加信息。如果您有任何关于这段代码的问题或需要进一步的帮助,请随时提问。

英文:

So I have a function that looks at a grid (2D array) and finds all the paths from the starting point to the end point. So far the algorithm works as intended and I get the values that I'm looking for.

The problem is that it takes forever. It can run over a 100 x 100 grid no problem, but once I get to a 10000 x 10000 grid, it'll take about 10 min to give back an answer, where I'm looking for maybe 1 min at most.

Here's what it looks like right now:

public void BFS(Point s, Point e){
/**
* North, South, East, West coordinates
*/
int[] x = {0,0,1,-1};
int[] y = {1,-1,0,0};
LinkedList&lt;Point&gt; queue = new LinkedList&lt;&gt;();
queue.add(s);
/**
* 2D int[][] grid that stores the distances of each point on the grid
* from the start
*/
int[][] dist = new int[numRow][numCol];
for(int[] a : dist){
Arrays.fill(a,-1);
}
/**
* &quot;obstacles&quot; is an array of Points that contain the (x, y) coordinates of obstacles on the grid
* designated as a -2, which the BFS algorithm will avoid.
*/
for(Point ob : obstacles){
dist[ob.x][ob.y] = -2;
}
// Start point
dist[s.x][s.y] = 0;
/**
* Loops over dist[][] from starting point, changing each [x][y] coordinate to the int
* value that is the distance from S.
*/
while(!queue.isEmpty()){                                        
Point p = queue.removeFirst();
for(int i = 0; i &lt; 4; i++){                             
int a = p.x + x[i];
int b = p.y + y[i];
if(a &gt;= 0 &amp;&amp; b &gt;= 0 &amp;&amp; a &lt; numRow &amp;&amp; b &lt; numCol &amp;&amp; dist[a][b] == -1){
dist[a][b] = 1 + dist[p.x][p.y];
Point tempPoint = new Point(a, b);
if(!queue.contains(tempPoint)){
queue.add(tempPoint);
}
}
}
}
/**
* Works backwards to find all shortest path points between S and E, and adds each
* point to an array called &quot;validPaths&quot;
*/
queue.add(e);
while(!queue.isEmpty()){
Point p = queue.removeFirst();
// Checks grid space (above, below, left, and right) from Point p
for(int i = 0; i &lt; 4; i++){
int curX = p.x + x[i];
int curY = p.y + y[i];
// Index Out of Bounds check
if(curX &gt;= 0 &amp;&amp; curY &gt;= 0 &amp;&amp; !(curX == start.x &amp;&amp; curY == start.y) &amp;&amp; curX &lt; numRow &amp;&amp; curY &lt; numCol){
if(dist[curX][curY] &lt; dist[p.x][p.y] &amp;&amp; dist[curX][curY] != -2){ // -2 is an obstacle
Point tempPoint = new Point(curX, curY);
if(!validPaths.contains(tempPoint)){
validPaths.add(tempPoint);
}
if(!queue.contains(tempPoint)){
queue.add(tempPoint);
}
}
}
}
}

So again, while it works, it's really slow. I'm trying to get a O(n + m), but I believe that it might be running in O(n^2).

Does anyone know any good ideas to make this faster?

答案1

得分: 2

一个明显的导致观察到的低效率的原因是比较 !validPaths.contains(tempPoint)!queue.contains(tempPoint),它们都是 O(n) 的操作。为了进行这些比较,您应该努力实现 O(1) 的比较,这可以通过使用特殊的数据结构,比如 哈希集合 或者简单的 位集合 来实现。

就目前而言,由于这些比较,您的实现显然是 O(n^2) 的,因为时间复杂度受到了影响。

英文:

An clear reason for the observed inefficiency are the comparisons !validPaths.contains(tempPoint) and !queue.contains(tempPoint) which are both O(n). To do these comparisons you should be striving for an O(1) comparison, which can be accomplished by using a special datastructure such as a hash-set or simply a bitset.

As it stands now, your implementation is clearly O(n^2) because of these comparisons.

huangapple
  • 本文由 发表于 2020年10月1日 03:38:54
  • 转载请务必保留本文链接:https://go.coder-hub.com/64144631.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定