Skip to content

6 RecursiveTask

Proyecto RecursiveTask

En este proyecto implementaremos una tarea recursiva que consiste en sumar todos los elementos de un array. Si el rango de elementos a sumar tiene menos de 10 valores, se calculará la suma recorriendo el rango de manera secuencial. Si el rango de elementos a sumar tiene 10 valores o más, seguiremos una de las siguientes dos estrategias:

  • Estrategia 1: Dividir el rango por la mitad y crear dos subtareas, cada una de las cuales sumará uno de los subrangos aplicando recursivamente el mismo algoritmo anterior.
  • Estrategia 2: Dividir el rango por la mitad, sumar secuencialmente el primer subrango y crear una subtarea para el segundo subrango, a la que se aplica recursivamente el mismo algoritmo anterior.
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;

class Main {

    public static void main(String[] args) throws InterruptedException, ExecutionException {
        new Main().start();
    }

    private void start() throws InterruptedException, ExecutionException {
        int size = 1000;
        int[] values = createArray(size);
        findSumSecuentially(values);
        findSumInParallel(values);
    }

    private void findSumSecuentially(int[] values) {
        long start = System.currentTimeMillis();
        long sum = 0;
        for (int value : values) {
            sum += value;
        }
        System.out.printf("Secuential sum done in %d millis with result %d\n", System.currentTimeMillis() - start, sum);
    }

    private void findSumInParallel(int[] values) throws InterruptedException, ExecutionException {
        long start = System.currentTimeMillis();
        SumTask sumTask = new SumTask(values, 0, values.length);
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        long max = forkJoinPool.invoke(sumTask);
        System.out.printf("Parallel sum done in %d millis with result %d\n", System.currentTimeMillis() - start, max);
        System.out.printf("Work steal count: %d\n", forkJoinPool.getStealCount());
        forkJoinPool.shutdown();
    }

    private int[] createArray(int size) {
        int[] values = new int[size];
        for (int i = 0; i < size; i++) {
            values[i] = ThreadLocalRandom.current().nextInt(100);
        }
        return values;
    }

}
import java.time.LocalTime;
import java.time.format.DateTimeFormatter;
import java.util.Objects;
import java.util.concurrent.RecursiveTask;

/**
 * Fork/join task that sums the elements of {@code values} in the half-open range
 * {@code [from, to)}.
 *
 * <p>Ranges with fewer than {@code THRESHOLD} (10) elements are summed sequentially. Larger
 * ranges are split at their midpoint and handled according to the task's {@link Strategy}.
 *
 * <p>Every split, join and sequential sum is printed to {@code System.out}, together with the
 * current thread name and a timestamp, so the fork/join behaviour can be observed.
 */
public class SumTask extends RecursiveTask<Long> {

    /** How a range that is too large to sum directly is divided between tasks. */
    public enum Strategy {
        /** Strategy 1: split the range in half and create a subtask for each half. */
        FORK_BOTH_HALVES,
        /**
         * Strategy 2: split the range in half, sum the first half sequentially in the current
         * task and create a subtask only for the second half.
         */
        FORK_SECOND_HALF
    }

    /** Ranges with fewer elements than this are summed sequentially. */
    private static final int THRESHOLD = 10;

    // DateTimeFormatter is immutable and thread-safe, so one shared instance serves every task.
    private static final DateTimeFormatter DATE_TIME_FORMATTER =
            DateTimeFormatter.ofPattern("HH:mm:ss:SSS");

    private final int[] values;
    private final int from;
    private final int to;
    private final Strategy strategy;

    /**
     * Creates a task that sums {@code values[from, to)} using
     * {@link Strategy#FORK_BOTH_HALVES} (strategy 1).
     *
     * @throws NullPointerException if {@code values} is null
     * @throws IllegalArgumentException unless {@code 0 <= from <= to <= values.length}
     */
    SumTask(int[] values, int from, int to) {
        this(values, from, to, Strategy.FORK_BOTH_HALVES);
    }

    /**
     * Creates a task that sums {@code values[from, to)} with the given split strategy.
     * Subtasks inherit the same strategy.
     *
     * @throws NullPointerException if {@code values} or {@code strategy} is null
     * @throws IllegalArgumentException unless {@code 0 <= from <= to <= values.length}
     */
    SumTask(int[] values, int from, int to, Strategy strategy) {
        this.values = Objects.requireNonNull(values, "values");
        if (from < 0 || from > to || to > values.length) {
            throw new IllegalArgumentException("Invalid range [" + from + ", " + to
                    + ") for array of length " + values.length);
        }
        this.from = from;
        this.to = to;
        this.strategy = Objects.requireNonNull(strategy, "strategy");
    }

    @Override
    protected Long compute() {
        // If the range is small enough, sum it directly (sequentially) in this task.
        if (to - from < THRESHOLD) {
            return sum(values, from, to);
        }
        switch (strategy) {
            case FORK_SECOND_HALF:
                return applyStrategy2();
            case FORK_BOTH_HALVES:
            default:
                return applyStrategy1();
        }
    }

    /**
     * Strategy 1: splits {@code [from, to)} at its midpoint, sums each half in its own
     * subtask and returns the combined sum.
     */
    private long applyStrategy1() {
        int pivot = midpoint(from, to);
        traceSplit(pivot);
        SumTask subTask1 = new SumTask(values, from, pivot, strategy);
        SumTask subTask2 = new SumTask(values, pivot, to, strategy);
        // invokeAll forks the subtasks and waits for both. The ForkJoinTask documentation
        // notes that forking both and then joining them in fork order is less efficient.
        invokeAll(subTask1, subTask2);
        long sum1 = subTask1.join();
        long sum2 = subTask2.join();
        long sum = sum1 + sum2;
        trace("[%d, %d) joined with %d", from, pivot, sum1);
        trace("[%d, %d) joined with %d", pivot, to, sum2);
        trace("[%d, %d) done with %d", from, to, sum);
        return sum;
    }

    /**
     * Strategy 2: splits {@code [from, to)} at its midpoint, forks a subtask for the second
     * half, sums the first half sequentially here and returns the combined sum.
     */
    private long applyStrategy2() {
        int pivot = midpoint(from, to);
        traceSplit(pivot);
        SumTask subTask2 = new SumTask(values, pivot, to, strategy);
        subTask2.fork();
        // Sum the first half in this task while the forked second half runs.
        long sum1 = sum(values, from, pivot);
        long sum2 = subTask2.join();
        long sum = sum1 + sum2;
        trace("[%d, %d) joined with %d", pivot, to, sum2);
        trace("[%d, %d) done with %d", from, to, sum);
        return sum;
    }

    /** Traces that this task's range is being split at {@code pivot}. */
    private void traceSplit(int pivot) {
        trace("[%d,%d) split in [%d,%d) y [%d,%d)", from, to, from, pivot, pivot, to);
    }

    /** Sequentially sums {@code values[from, to)} and traces the result. */
    private static long sum(int[] values, int from, int to) {
        long sum = 0;
        for (int i = from; i < to; i++) {
            sum += values[i];
        }
        trace("[%d, %d) done secuentially with %d", from, to, sum);
        return sum;
    }

    /**
     * Returns the midpoint of {@code [lo, hi)} for non-negative bounds. The unsigned shift
     * gives the correct result even if {@code lo + hi} overflows {@code int}.
     */
    private static int midpoint(int lo, int hi) {
        return (lo + hi) >>> 1;
    }

    /** Prints a formatted message prefixed with the current thread name and a timestamp. */
    private static void trace(String format, Object... args) {
        System.out.printf("%s - %s - %s%n",
                Thread.currentThread().getName(),
                DATE_TIME_FORMATTER.format(LocalTime.now()),
                String.format(format, args));
    }

}