/home/caleb/ASDV-Java/Semester 4/Assignments/MP2_CalebFontenot/src/main/java/com/calebfontenot/mp2_calebfontenot/Matrices.java
/*
 * Click nbfs://nbhost/SystemFileSystem/Templates/Licenses/license-default.txt to change this license
 * Click nbfs://nbhost/SystemFileSystem/Templates/Classes/Class.java to edit this template
 */
package com.calebfontenot.mp2_calebfontenot;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

/**
 *
 * @author caleb
 */
public class Matrices {

    /**
     * Demo entry point: multiplies two random matrices in parallel and prints them.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Matrices.multiplyParallel();
    }

    /**
     * Creates a {@code rows} x {@code columns} matrix whose entries are random integers in [1, 9].
     *
     * @param rows    number of rows
     * @param columns number of entries per row (a non-positive value yields empty rows)
     * @return a new, mutable row-major matrix
     */
    public static ArrayList<ArrayList<BigInteger>> createRandomMatrix(int rows, int columns) {
        ArrayList<ArrayList<BigInteger>> matrix = new ArrayList<>(rows);
        for (int i = 0; i < rows; ++i) {
            ArrayList<BigInteger> row = new ArrayList<>(Math.max(columns, 0));
            for (int j = 0; j < columns; ++j) {
                // Math.random() is in [0, 1), so this yields 1..9 inclusive.
                row.add(BigInteger.valueOf(1 + (int) (Math.random() * 9)));
            }
            matrix.add(row);
        }
        return matrix;
    }

    /**
     * Demo: multiplies a random 2x2 matrix by a random 2x3 matrix in parallel and prints A, B
     * and the product to standard output.
     */
    public static void multiplyParallel() {
        ArrayList<ArrayList<BigInteger>> A = Matrices.createRandomMatrix(2, 2);
        ArrayList<ArrayList<BigInteger>> B = Matrices.createRandomMatrix(2, 3);

        ArrayList<ArrayList<BigInteger>> mul = multiplyParallel(A, B);

        System.out.println("MATRIX A");
        printMatrix(A);
        System.out.println("\nMATRIX B");
        printMatrix(B);
        System.out.println("\nMATRIX AxB");
        printMatrix(mul);
    }

    /**
     * Computes the matrix product {@code A x B} with a fork/join task that splits A's rows.
     *
     * <p>A and B are read by pool worker threads while the product is computed, so they must not
     * be modified until this method returns. Neither argument is modified.
     *
     * @param A left operand: m rows of n entries each (row-major)
     * @param B right operand: n rows of p entries each (row-major)
     * @return a new m x p matrix; empty when A has no rows
     * @throws IllegalArgumentException if either matrix is jagged, or if the number of columns of
     *                                  A differs from the number of rows of B
     */
    public static ArrayList<ArrayList<BigInteger>> multiplyParallel(
            ArrayList<ArrayList<BigInteger>> A, ArrayList<ArrayList<BigInteger>> B) {
        requireRectangular(A, "A");
        requireRectangular(B, "B");
        MatricesMultiplication task = new MatricesMultiplication(0, A.size() - 1, A, B);
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(task);
        } finally {
            // Release the pool's worker threads once the result (or failure) is available.
            pool.shutdown();
        }
    }

    /**
     * Throws if {@code matrix} has rows of differing lengths. An empty matrix is accepted.
     *
     * @param matrix matrix to check
     * @param name   operand name used in the exception message
     * @throws IllegalArgumentException if some row's length differs from row 0's
     */
    private static void requireRectangular(ArrayList<ArrayList<BigInteger>> matrix, String name) {
        if (matrix.isEmpty()) {
            return;
        }
        int width = matrix.get(0).size();
        for (int r = 1; r < matrix.size(); ++r) {
            int rowWidth = matrix.get(r).size();
            if (rowWidth != width) {
                throw new IllegalArgumentException(name + " is not rectangular: row " + r
                        + " has " + rowWidth + " entries but row 0 has " + width);
            }
        }
    }

    /**
     * Prints each row of {@code matrix} on its own line using {@link ArrayList#toString()}.
     *
     * @param matrix matrix to print
     */
    private static void printMatrix(ArrayList<ArrayList<BigInteger>> matrix) {
        for (ArrayList<BigInteger> row : matrix) {
            System.out.println(row);
        }
    }

    /**
     * Fork/join task that computes rows {@code startIndex..endIndex} (inclusive) of {@code A x B}.
     *
     * <p>A range of more than {@link #HOW_MANY_ROWS_IN_PARALLEL} rows is split in half: the left
     * half is forked and the right half is computed in the current thread. The returned rows are
     * in ascending row order. The task only reads A and B and always returns newly built lists.
     */
    static class MatricesMultiplication extends RecursiveTask<ArrayList<ArrayList<BigInteger>>> {

        /** Row-count threshold at or below which a task computes its rows sequentially. */
        static final int HOW_MANY_ROWS_IN_PARALLEL = 3;

        ArrayList<ArrayList<BigInteger>> A;
        ArrayList<ArrayList<BigInteger>> B;
        /** Never assigned or read by this class; kept in case other package code refers to it. */
        ArrayList<ArrayList<BigInteger>> AxB;
        int startIndex;
        int endIndex;

        /**
         * Creates a task for a contiguous range of A's rows.
         *
         * @param startIndex first row of A to compute (inclusive)
         * @param endIndex   last row of A to compute (inclusive); {@code startIndex - 1} for none
         * @param A          left operand
         * @param B          right operand; must have as many rows as A has columns
         * @throws IllegalArgumentException if A is non-empty and its first row's length differs
         *                                  from the number of rows of B
         */
        public MatricesMultiplication(int startIndex, int endIndex,
                ArrayList<ArrayList<BigInteger>> A, ArrayList<ArrayList<BigInteger>> B) {
            // Fail fast on incompatible shapes. Otherwise the task would either throw
            // IndexOutOfBounds mid-computation or silently ignore B's extra rows.
            if (!A.isEmpty() && A.get(0).size() != B.size()) {
                throw new IllegalArgumentException("Incompatible dimensions: A has "
                        + A.get(0).size() + " columns but B has " + B.size() + " rows");
            }
            this.startIndex = startIndex;
            this.endIndex = endIndex;
            this.A = A;
            this.B = B;
        }

        /**
         * Computes this task's rows of the product, splitting the range while it is above the
         * threshold.
         *
         * @return the product rows for {@code startIndex..endIndex}, in order
         */
        @Override
        protected ArrayList<ArrayList<BigInteger>> compute() {
            // Base case: few enough rows to compute directly in this thread.
            if (endIndex - startIndex + 1 <= HOW_MANY_ROWS_IN_PARALLEL) {
                return multiplyRows();
            }

            // Unsigned shift avoids int overflow of startIndex + endIndex.
            int middle = (startIndex + endIndex) >>> 1;
            MatricesMultiplication leftTask = new MatricesMultiplication(startIndex, middle, A, B);
            MatricesMultiplication rightTask = new MatricesMultiplication(middle + 1, endIndex, A, B);

            leftTask.fork();
            ArrayList<ArrayList<BigInteger>> rightResult = rightTask.compute();
            ArrayList<ArrayList<BigInteger>> leftResult = leftTask.join();

            // Subtask results are freshly built and referenced nowhere else. Their rows can be
            // adopted as-is instead of being copied again at every level of the recursion.
            ArrayList<ArrayList<BigInteger>> result =
                    new ArrayList<>(leftResult.size() + rightResult.size());
            result.addAll(leftResult);
            result.addAll(rightResult);
            return result;
        }

        /**
         * Computes rows {@code startIndex..endIndex} of {@code A x B} in the current thread.
         *
         * @return the product rows, in order
         * @throws IllegalArgumentException if one of the rows of A has a length different from
         *                                  the number of rows of B
         */
        private ArrayList<ArrayList<BigInteger>> multiplyRows() {
            int inner = B.size();
            // A B with no rows cannot encode a column count; treat it as zero columns.
            int columns = B.isEmpty() ? 0 : B.get(0).size();
            ArrayList<ArrayList<BigInteger>> result =
                    new ArrayList<>(Math.max(0, endIndex - startIndex + 1));
            for (int i = startIndex; i <= endIndex; i++) {
                ArrayList<BigInteger> rowA = A.get(i);
                if (rowA.size() != inner) {
                    throw new IllegalArgumentException("Row " + i + " of A has " + rowA.size()
                            + " entries but B has " + inner + " rows");
                }
                ArrayList<BigInteger> rowResult = new ArrayList<>(columns);
                for (int j = 0; j < columns; j++) {
                    BigInteger sum = BigInteger.ZERO;
                    for (int k = 0; k < inner; k++) {
                        sum = sum.add(rowA.get(k).multiply(B.get(k).get(j)));
                    }
                    rowResult.add(sum);
                }
                result.add(rowResult);
            }
            return result;
        }
    }

}