Is the implementation of the Recursive Task below correct?


Question


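I am beginning to understand the implementation of RecursiveTask and RecursiveAction. Based on my understanding and some Java documentation, I came up with the code below to add up all the numbers in an array.

I need help correcting this. Please point out where I have gone wrong.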

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ForkJoinPoolTest {

    public static void main(String[] args) {

        ForkJoinPool pool = new ForkJoinPool(4);
        long[] numbers = {1,2,3,4,5,6,7,8,9};
        AdditionTask newTask = new AdditionTask(numbers, 0, numbers.length -1 );
        ForkJoinTask<Long> submit = pool.submit(newTask);
        System.out.println(submit.join());
        
    }
}

class AdditionTask extends RecursiveTask<Long> {

    long[] numbers;
    int start;
    int end;

    public AdditionTask(long[] numbers, int start, int end) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {

        if ((end - start) > 2) {

            int length = numbers.length;
            int mid = (length % 2 == 0) ? length / 2 : (length - 1) / 2;
            AdditionTask leftSide = new AdditionTask(numbers, 0, mid);

            leftSide.fork();

            AdditionTask rightSide = new AdditionTask(numbers, mid+1, length-1);
            return rightSide.compute() + leftSide.join();

        } else {
            return numbers[0] + numbers[1];
        }
    }
}

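New Code [Fixed]
This is the code I fixed. It seems to work well, but only with small arrays. In the example below the array size is 100000 and the sum is wrong. Why does it calculate the wrong sum?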

import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ForkJoinPoolTest {

    public static void main(String[] args) {

        Random r = new Random();
        int low = 10000;
        int high = 100000;

        int size = 100000;

        long[] numbers = new long[size];
        int sum = 0;
        for (int i = 0; i < size; i++) {
            int n = r.nextInt(high - low) + low;
            numbers[i] = n;
            sum += numbers[i];
        }

        long s = System.currentTimeMillis();
        ForkJoinPool pool = new ForkJoinPool(1);
        AdditionTask newTask = new AdditionTask(numbers, 0, numbers.length-1);
        ForkJoinTask<Long> submit = pool.submit(newTask);
        System.out.println("Expected Answer: " + sum + ", Actual: " + submit.join());
        long e = System.currentTimeMillis();
        System.out.println("Total time taken: " + (e - s) + " ms in parallel Operation");

        long s2 = System.currentTimeMillis();
        System.out.println("Started: " + s2);

        int manualSum = 0;
        for (long number : numbers) {
            manualSum += number;
        }

        System.out.println("Expected Answer: " + sum + ", Actual: " + manualSum);
        long e2 = System.currentTimeMillis();
        System.out.println("Ended: " + e2);
        System.out.println("Total time taken: " + (e2 - s2) + " ms in sequential Operation");
    }
}

class AdditionTask extends RecursiveTask<Long> {

    long[] numbers;
    int start;
    int end;

    public AdditionTask(long[] numbers, int start, int end) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {

        int length = (start == 0) ? end +1 : (end - (start - 1));

        if (length > 2) {

            int mid = (length % 2 == 0) ? length / 2 : (length - 1) / 2;

            AdditionTask leftSide = new AdditionTask(numbers, start, (start+mid));
            leftSide.fork();

            AdditionTask rightSide = new AdditionTask(numbers, (start+mid)+1, end);

            Long rightSideLong = rightSide.compute();

            Long leftSideLong = leftSide.join();
            Long total = rightSideLong + leftSideLong;

            return total;

        } else {

            if (start == end) {
                return numbers[start];
            }
            return numbers[start] + numbers[end];

        }
    }
}

Answer 1

Score: 2

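The second version of your parallel calculation is correct. But both non-parallel computations in your code are broken, as they use int for their sum, which will overflow for large arrays: with 100,000 values drawn from [10000, 100000), the expected total is about 5.5 billion, far above Integer.MAX_VALUE (2,147,483,647). When you fix them to also use long, they will produce the same result as your parallel computation.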
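To see the overflow in isolation, here is a minimal sketch that sums the same 100,000 values once into an int, as the question's `sum` and `manualSum` loops do, and once into a long. The class name `OverflowDemo` and the constant value 55,000 (roughly the mean of the question's random range) are assumptions chosen for illustration:

```java
import java.util.Arrays;

public class OverflowDemo {

    // Accumulates into an int, like the question's `sum` and `manualSum` loops.
    static int sumAsInt(long[] numbers) {
        int sum = 0;
        for (long n : numbers) {
            sum += n; // compound assignment silently narrows the long result to int
        }
        return sum;
    }

    // Accumulates into a long, as suggested above.
    static long sumAsLong(long[] numbers) {
        long sum = 0;
        for (long n : numbers) {
            sum += n;
        }
        return sum;
    }

    public static void main(String[] args) {
        long[] numbers = new long[100_000];
        Arrays.fill(numbers, 55_000L); // hypothetical fixed value instead of random ones
        System.out.println("int sum:  " + sumAsInt(numbers));  // prints 1205032704 (wrapped around)
        System.out.println("long sum: " + sumAsLong(numbers)); // prints 5500000000
    }
}
```

Note that `sum += n` compiles even though `n` is a long: a compound assignment includes an implicit narrowing cast, so the wraparound happens silently instead of producing a compile error.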

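Still, there are some things to improve. First, you should get rid of those conditionals:

int length = (start == 0) ? end + 1 : (end - (start - 1));

and

int mid = (length % 2 == 0) ? length / 2 : (length - 1) / 2;

They provide no benefit over the simpler

int length = end - (start - 1); // or end - start + 1

and

int mid = length / 2;

Both forms compute the same values: when start is 0, end - (start - 1) is already end + 1, and for a positive odd length, integer division makes length / 2 equal to (length - 1) / 2.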
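If you want to convince yourself that the simpler expressions are equivalent, a small brute-force check like the following sketch compares both versions for every 0 <= start <= end up to a bound. The class name `SplitCheck` and its method names are hypothetical:

```java
public class SplitCheck {

    // The expressions from the question's compute()
    static int lengthOld(int start, int end) {
        return (start == 0) ? end + 1 : (end - (start - 1));
    }

    static int midOld(int length) {
        return (length % 2 == 0) ? length / 2 : (length - 1) / 2;
    }

    // The simpler expressions suggested above
    static int lengthNew(int start, int end) {
        return end - start + 1;
    }

    static int midNew(int length) {
        return length / 2;
    }

    // Returns true if both versions agree for every 0 <= start <= end <= maxIndex
    static boolean agree(int maxIndex) {
        for (int start = 0; start <= maxIndex; start++) {
            for (int end = start; end <= maxIndex; end++) {
                int length = lengthOld(start, end);
                if (length != lengthNew(start, end) || midOld(length) != midNew(length)) {
                    return false;
                }
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(agree(1_000)); // prints true
    }
}
```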

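Then, efficient parallel processing should not decompose the task as far as possible, but should take the actually achievable parallelism into account. You can use getSurplusQueuedTaskCount() for that: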

@Override
protected Long compute() {
    int length = end - (start - 1);
    // only split when benefit from parallel processing is likely
    if (length > 2 && getSurplusQueuedTaskCount() < 2) {
        int mid = length / 2;

        AdditionTask leftSide = new AdditionTask(numbers, start, (start + mid));
        leftSide.fork();

        AdditionTask rightSide = new AdditionTask(numbers, (start + mid) + 1, end);

        Long rightSideLong = rightSide.compute();

        // do it in this thread if no worker thread has picked it up yet
        Long leftSideLong = leftSide.tryUnfork() ? leftSide.compute() : leftSide.join();
        Long total = rightSideLong + leftSideLong;

        return total;
    } else { // do it sequentially
        long sum = 0;
        for (int ix = start; ix <= end; ix++) sum += numbers[ix];
        return sum;
    }
}
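For reference, here is a sketch of how the pieces might fit together in one runnable program: the compute() above inside a task class, with the boxed Long locals replaced by long and every accumulator, including the sequential one, using long. The names `AdditionDemo`, `parallelSum` and `sequentialSum`, the fixed Random seed, and the parallelism of 4 are assumptions made for this example, not part of the original code:

```java
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class AdditionDemo {

    static class AdditionTask extends RecursiveTask<Long> {
        private final long[] numbers;
        private final int start; // inclusive
        private final int end;   // inclusive

        AdditionTask(long[] numbers, int start, int end) {
            this.numbers = numbers;
            this.start = start;
            this.end = end;
        }

        @Override
        protected Long compute() {
            int length = end - start + 1;
            // split only while this worker has few surplus tasks queued
            if (length > 2 && getSurplusQueuedTaskCount() < 2) {
                int mid = length / 2;
                AdditionTask leftSide = new AdditionTask(numbers, start, start + mid);
                leftSide.fork();
                AdditionTask rightSide = new AdditionTask(numbers, start + mid + 1, end);
                long right = rightSide.compute();
                // run the left half here if no other worker has stolen it yet
                long left = leftSide.tryUnfork() ? leftSide.compute() : leftSide.join();
                return left + right;
            }
            long sum = 0; // sequential part, accumulated in a long
            for (int ix = start; ix <= end; ix++) sum += numbers[ix];
            return sum;
        }
    }

    static long parallelSum(long[] numbers, int parallelism) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        try {
            return pool.invoke(new AdditionTask(numbers, 0, numbers.length - 1));
        } finally {
            pool.shutdown();
        }
    }

    static long sequentialSum(long[] numbers) {
        long sum = 0; // long, not int, so large totals do not overflow
        for (long n : numbers) sum += n;
        return sum;
    }

    public static void main(String[] args) {
        Random r = new Random(42); // hypothetical fixed seed for reproducible runs
        int low = 10_000, high = 100_000, size = 100_000;
        long[] numbers = new long[size];
        for (int i = 0; i < size; i++) numbers[i] = r.nextInt(high - low) + low;

        System.out.println("Expected: " + sequentialSum(numbers)
                + ", Actual: " + parallelSum(numbers, 4));
    }
}
```

Here pool.invoke(...) runs the task and waits for its result, which has the same effect as the question's submit(...) followed by join().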

huangapple
  • Posted on 2020-10-16 02:46:34
  • Please keep the link to this article when reposting: https://go.coder-hub.com/64377931.html