6 RecursiveTask
Proyecto RecursiveTask
En este proyecto implementaremos una tarea recursiva que consiste en sumar todos los elementos de un array. Si el rango de elementos a sumar tiene menos de 10 valores, se calculará la suma recorriendo el rango de manera secuencial. Si el rango de elementos a sumar tiene 10 valores o más, seguiremos una de las siguientes dos estrategias:
- Estrategia 1: Dividir el rango por la mitad y crear dos subtareas cada una de las cuales sume un subrango aplicando recursivamente el mismo algoritmo anterior.
- Estrategia 2: Dividir el rango por la mitad y hacer que el primer subrango sea sumado secuencialmente y crear una subtarea para el segundo subrango aplicando recursivamente el mismo algoritmo anterior.
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
/**
 * Demo driver: fills an array with random values in {@code [0, 100)}, sums it once with a plain
 * loop and once with a fork/join {@code SumTask}, and prints the elapsed time and result of each.
 */
class Main {

    public static void main(String[] args) throws InterruptedException, ExecutionException {
        new Main().start();
    }

    /** Builds the input array and runs both summation variants over it. */
    private void start() throws InterruptedException, ExecutionException {
        int size = 1000;
        int[] values = createArray(size);
        findSumSecuentially(values);
        findSumInParallel(values);
    }

    /**
     * Sums {@code values} with a single loop on the calling thread and prints time and result.
     *
     * @param values the numbers to add up
     */
    private void findSumSecuentially(int[] values) {
        long start = System.currentTimeMillis();
        long sum = 0;
        for (int value : values) {
            sum += value;
        }
        System.out.printf("Secuential sum done in %d millis with result %d\n", System.currentTimeMillis() - start, sum);
    }

    /**
     * Sums {@code values} by invoking a {@code SumTask} over the whole array on a dedicated
     * {@link ForkJoinPool}, then prints time, result and the pool's steal count.
     *
     * @param values the numbers to add up
     */
    private void findSumInParallel(int[] values) throws InterruptedException, ExecutionException {
        long start = System.currentTimeMillis();
        SumTask sumTask = new SumTask(values, 0, values.length);
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        try {
            long sum = forkJoinPool.invoke(sumTask);
            System.out.printf("Parallel sum done in %d millis with result %d\n", System.currentTimeMillis() - start, sum);
            System.out.printf("Work steal count: %d\n", forkJoinPool.getStealCount());
        } finally {
            // Shut the pool down even when the task fails, so its worker threads are released.
            forkJoinPool.shutdown();
        }
    }

    /**
     * Creates an array of {@code size} random ints, each in {@code [0, 100)}.
     *
     * @param size number of elements to generate
     * @return the freshly filled array
     */
    private int[] createArray(int size) {
        int[] values = new int[size];
        ThreadLocalRandom random = ThreadLocalRandom.current();
        for (int i = 0; i < size; i++) {
            values[i] = random.nextInt(100);
        }
        return values;
    }
}
import java.time.LocalTime;
import java.time.format.DateTimeFormatter;
import java.util.concurrent.RecursiveTask;
/**
 * Fork/join task that adds up the elements of an {@code int} array in the half-open index range
 * {@code [from, to)}.
 *
 * <p>Ranges with fewer than {@link #SEQUENTIAL_THRESHOLD} elements are summed with a plain loop.
 * Larger ranges are split at the midpoint and handled by one of two strategies:
 * {@link #applyStrategy1} (two forked subtasks) or {@link #applyStrategy2} (first half summed
 * inline, second half forked). Every step is traced to standard output together with the
 * executing thread's name and the current time.
 */
public class SumTask extends RecursiveTask<Long> {

    private static final long serialVersionUID = 1L;

    /** Ranges with fewer elements than this are summed sequentially instead of being split. */
    private static final int SEQUENTIAL_THRESHOLD = 10;

    /**
     * Timestamp format for trace lines. Kept static so that tasks do not each carry a reference
     * to it and so that the (serializable) task has no non-serializable instance state.
     */
    private static final DateTimeFormatter TIME_FORMATTER =
            DateTimeFormatter.ofPattern("HH:mm:ss:SSS");

    private final int[] values;
    private final int from;
    private final int to;

    /**
     * @param values array holding the numbers to add; it is read, never modified
     * @param from   index of the first element to include
     * @param to     index one past the last element to include
     */
    SumTask(int[] values, int from, int to) {
        this.values = values;
        this.from = from;
        this.to = to;
    }

    /**
     * Sums this task's range, sequentially when it is small and by splitting otherwise.
     *
     * @return the sum of {@code values[from..to)}
     */
    @Override
    protected Long compute() {
        if (to - from < SEQUENTIAL_THRESHOLD) {
            return sum(values, from, to);
        }
        // Replace with applyStrategy2(values, from, to) to try the alternative strategy.
        return applyStrategy1(values, from, to);
    }

    /**
     * Strategy 1: split the range in half and fork one subtask per half.
     *
     * @return the sum of both halves
     */
    private long applyStrategy1(int[] values, int from, int to) {
        // from + (to - from) / 2 cannot overflow, unlike (to + from) / 2.
        int pivot = from + (to - from) / 2;
        trace("[%d,%d) split in [%d,%d) y [%d,%d)", from, to, from, pivot, pivot, to);
        SumTask subTask1 = new SumTask(values, from, pivot);
        SumTask subTask2 = new SumTask(values, pivot, to);
        subTask1.fork();
        subTask2.fork();
        // Join in reverse fork order (innermost first): the most recently forked task sits on
        // top of this worker's deque, so joining it first is the order ForkJoinTask recommends.
        long sum2 = subTask2.join();
        long sum1 = subTask1.join();
        long sum = sum1 + sum2;
        trace("[%d, %d) joined with %d", from, pivot, sum1);
        trace("[%d, %d) joined with %d", pivot, to, sum2);
        trace("[%d, %d) done with %d", from, to, sum);
        return sum;
    }

    /**
     * Strategy 2: split the range in half, fork a subtask for the second half and sum the first
     * half sequentially on the current thread while the subtask runs.
     *
     * @return the sum of both halves
     */
    private long applyStrategy2(int[] values, int from, int to) {
        int pivot = from + (to - from) / 2;
        trace("[%d,%d) split in [%d,%d) y [%d,%d)", from, to, from, pivot, pivot, to);
        SumTask subTask2 = new SumTask(values, pivot, to);
        subTask2.fork();
        long sum1 = sum(values, from, pivot);
        long sum2 = subTask2.join();
        long sum = sum1 + sum2;
        trace("[%d, %d) joined with %d", pivot, to, sum2);
        trace("[%d, %d) done with %d", from, to, sum);
        return sum;
    }

    /**
     * Adds up {@code values[from..to)} with a plain loop and traces the result.
     *
     * @return the sum of the range; 0 when the range is empty
     */
    private long sum(int[] values, int from, int to) {
        long sum = 0;
        for (int i = from; i < to; i++) {
            sum += values[i];
        }
        // The range is half-open, so it is printed as [from, to) like every other trace line.
        trace("[%d, %d) done secuentially with %d", from, to, sum);
        return sum;
    }

    /**
     * Prints one trace line of the form {@code <thread name> - <HH:mm:ss:SSS> - <message>}.
     *
     * @param format {@link String#format} pattern for the message part
     * @param args   arguments for {@code format}
     */
    private static void trace(String format, Object... args) {
        System.out.printf("%s - %s - %s\n",
                Thread.currentThread().getName(),
                TIME_FORMATTER.format(LocalTime.now()),
                String.format(format, args));
    }
}