/*
 * $Id: MatrixMultMPP.java,v 1.5 2002/08/03 17:17:42 kredel Exp $
 */

//package edu.unima.ky.parallel.mpijava;

import java.io.IOException;
import mpp.Communicator;
import mpp.BlockCommunicator;
import mpp.StreamCommunicator;


/**
 * Matrix multiplication, sequential and parallel using MPP.
 *
 * <p>The parallel variant {@link #parmult} distributes the rows of the left
 * factor round-robin over the ranks of the communicator given to the
 * constructor; rank 0 acts as root, owns the input data and collects the
 * complete result.
 *
 * @author Akitoshi Yoshida
 * @author Heinz Kredel.
 */
public class MatrixMultMPP {

  /**
   * Default tolerance of {@link #matcheck0(double[][])}. Double.MIN_VALUE is
   * the smallest positive subnormal double (about 4.9e-324), so this value is
   * about 4.9e-321: effectively an exact-zero test. Use
   * {@link #matcheck0(double[][], double)} when a real tolerance is needed.
   */
  private static final double DEFAULT_EPS = Double.MIN_VALUE * 1000.0;

  /** Communicator used by {@link #parmult}; rank 0 acts as root. */
  protected Communicator mpi_comm_world;

  /**
   * Creates a multiplier that communicates over the given communicator.
   *
   * @param comm communicator used by {@link #parmult}; the sequential
   *        helpers do not use it.
   */
  public MatrixMultMPP(Communicator comm) {
    mpi_comm_world = comm;
  }

  /**
   * Computes C = A * B sequentially.
   *
   * @param C result, overwritten; needs A.length rows of at least
   *        B[0].length columns.
   * @param A left factor; each row needs at least B.length entries.
   * @param B right factor; B[0] is read whenever A is non-empty.
   */
  public void seqmult(double[][] C, double[][] A, double[][] B) {
    for (int i = 0; i < A.length; i++) {
      for (int j = 0; j < B[0].length; j++) {
        double c = 0.0;
        for (int k = 0; k < B.length; k++) {
          c += A[i][k] * B[k][j];
        }
        C[i][j] = c;
      }
    }
  }

  /**
   * Computes C = A * B in parallel over {@link #mpi_comm_world}.
   *
   * <p>Every rank must call this method with arrays of the same dimensions.
   * Row k of A and C is owned by rank {@code k % size}. Rank 0 sends each
   * remote row of A to its owner, then broadcasts B row by row. Every rank
   * computes its own rows of C, and the remote rows of C are sent back to
   * rank 0.
   *
   * <p>On return only rank 0 holds the complete product. The other ranks
   * hold just their own rows of C, and their A and B have been overwritten
   * with the data received from rank 0.
   *
   * @param C result, overwritten (complete on rank 0 only).
   * @param A left factor; significant on rank 0, receives the owned rows on
   *        the other ranks.
   * @param B right factor; significant on rank 0, receives the broadcast on
   *        the other ranks.
   * @throws IOException propagated from the communicator operations.
   */
  public void parmult(double[][] C, double[][] A, double[][] B)
      throws IOException {
    final boolean prnt = false; // set to true for extra progress messages

    int myid = mpi_comm_world.rank();
    int numprocs = mpi_comm_world.size();
    // This status line is printed on every rank regardless of prnt.
    System.out.println("MatrixMultMPP on " + myid + " of " + numprocs + " ");

    int counts = A.length;
    int[] assign = new int[counts];
    for (int k = 0; k < counts; k++) {
      assign[k] = k % numprocs; // round-robin row ownership
    }

    // Distribute the rows of A. Both sides iterate rows in ascending order,
    // so each worker's receives line up with the root's sends (assuming the
    // communicator delivers messages between two ranks in order).
    if (myid == 0) {
      System.out.print("assign = ");
      vecprint(assign);
      for (int k = 0; k < counts; k++) {
        if (assign[k] != myid) {
          mpi_comm_world.send(A[k], 0, A[k].length, assign[k]);
        }
      }
    } else {
      for (int k = 0; k < counts; k++) {
        if (assign[k] == myid) {
          mpi_comm_world.recv(A[k], 0, A[k].length, 0);
        }
      }
    }

    if (myid == 0 && prnt) {
      System.out.print("sending B ");
    }
    for (int x = 0; x < B.length; x++) {
      mpi_comm_world.bcast(B[x], 0, B[x].length, 0);
    }

    // Compute the owned rows of C. dotmult sums in the same order as
    // seqmult, so results match the sequential version.
    for (int i = 0; i < counts; i++) {
      if (assign[i] == myid) {
        for (int j = 0; j < B[0].length; j++) {
          dotmult(C, A, B, i, j);
        }
      }
    }

    // Gather the remote rows of C on rank 0, again in ascending row order.
    if (myid == 0) {
      if (prnt) {
        System.out.print("receiving C ");
      }
      for (int k = 0; k < counts; k++) {
        if (assign[k] != myid) {
          mpi_comm_world.recv(C[k], 0, C[k].length, assign[k]);
        }
      }
    } else {
      for (int k = 0; k < counts; k++) {
        if (assign[k] == myid) {
          mpi_comm_world.send(C[k], 0, C[k].length, 0);
        }
      }
    }
  }

  /**
   * Computes C = A - B element-wise.
   *
   * @param C result, overwritten; its shape determines the range computed.
   * @param A minuend, at least as large as C.
   * @param B subtrahend, at least as large as C.
   */
  public void seqdiff(double[][] C, double[][] A, double[][] B) {
    for (int i = 0; i < C.length; i++) {
      for (int j = 0; j < C[i].length; j++) {
        C[i][j] = A[i][j] - B[i][j];
      }
    }
  }

  /**
   * Computes the single entry C[i][j] = sum over k of A[i][k] * B[k][j].
   *
   * @param C result, only C[i][j] is written.
   * @param A left factor; row i needs at least B.length entries.
   * @param B right factor.
   * @param i row of A.
   * @param j column of B.
   */
  public void dotmult(double[][] C, double[][] A, double[][] B, int i, int j) {
    double c = 0.0;
    for (int k = 0; k < B.length; k++) {
      c += A[i][k] * B[k][j];
    }
    C[i][j] = c;
  }

  /**
   * Computes C = A * B sequentially, one entry at a time via
   * {@link #dotmult}.
   *
   * @param C result, overwritten.
   * @param A left factor.
   * @param B right factor; B[0] is read whenever A is non-empty.
   */
  public void seq2mult(double[][] C, double[][] A, double[][] B) {
    for (int i = 0; i < A.length; i++) {
      for (int j = 0; j < B[0].length; j++) {
        dotmult(C, A, B, i, j);
      }
    }
  }

  /**
   * Creates an n x m matrix with entries from Math.random(), i.e. in [0, 1).
   *
   * @param n rows of result.
   * @param m columns of result.
   * @return a new two-dimensional double array.
   */
  public double[][] matgen(int n, int m) {
    double[][] A = new double[n][m];
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < m; j++) {
        A[i][j] = Math.random();
      }
    }
    return A;
  }

  /**
   * Creates an n x m zero matrix.
   *
   * @param n rows of result.
   * @param m columns of result.
   * @return a new two-dimensional double array of zeros.
   */
  public double[][] matgen0(int n, int m) {
    // Java zero-initializes new double arrays, so no explicit fill is needed.
    return new double[n][m];
  }

  /**
   * Creates an n x m matrix with 1.0 on the main diagonal and 0.0 elsewhere
   * (the identity matrix when n == m).
   *
   * @param n rows of result.
   * @param m columns of result.
   * @return a new two-dimensional double array.
   */
  public double[][] matgen1(int n, int m) {
    double[][] A = new double[n][m];
    int diag = Math.min(n, m);
    for (int i = 0; i < diag; i++) {
      A[i][i] = 1.0;
    }
    return A;
  }

  /**
   * Prints A to standard output, one row per line, entries separated by a
   * blank.
   *
   * @param A two-dimensional double array.
   */
  public void matprint(double[][] A) {
    for (int i = 0; i < A.length; i++) {
      for (int j = 0; j < A[i].length; j++) {
        System.out.print(A[i][j] + " ");
      }
      System.out.println();
    }
  }

  /**
   * Prints V to standard output on one line, entries separated by a blank.
   *
   * @param V vector to print.
   */
  public void vecprint(int[] V) {
    for (int i = 0; i < V.length; i++) {
      System.out.print(V[i] + " ");
    }
    System.out.println();
  }

  /**
   * Tests whether every entry of A has magnitude at most a tiny default
   * tolerance (effectively an exact-zero test; see {@link #DEFAULT_EPS}).
   *
   * @param A two-dimensional double array.
   * @return true if all entries are within the default tolerance of zero;
   *         false if any entry is larger or NaN.
   */
  public boolean matcheck0(double[][] A) {
    return matcheck0(A, DEFAULT_EPS);
  }

  /**
   * Tests whether every entry of A has magnitude at most eps.
   *
   * @param A two-dimensional double array.
   * @param eps non-negative absolute tolerance.
   * @return true if all entries satisfy |a| &lt;= eps; false if any entry is
   *         larger or NaN.
   */
  public boolean matcheck0(double[][] A, double eps) {
    for (int i = 0; i < A.length; i++) {
      for (int j = 0; j < A[i].length; j++) {
        // Written as !(x <= eps) so that NaN entries count as non-zero.
        if (!(Math.abs(A[i][j]) <= eps)) {
          return false;
        }
      }
    }
    return true;
  }

  /**
   * Benchmarks sequential against parallel multiplication and checks that
   * the parallel result matches the sequential one.
   *
   * @param args optional: n (rows of A, default 300), m (columns of A,
   *        default 300), prnt ("true" to print the matrices).
   * @throws IOException propagated from the communicator.
   */
  public static void main(String[] args) throws IOException {
    int n = parseSizeArg(args, 0, 300);
    int m = parseSizeArg(args, 1, 300);
    boolean prnt = args.length > 2 && Boolean.valueOf(args[2]).booleanValue();

    Communicator comm = new BlockCommunicator();
    // Alternative implementation: Communicator comm = new StreamCommunicator();
    try {
      runBenchmark(comm, n, m, prnt);
    } finally {
      // Release the communicator even if the benchmark fails.
      comm.close();
    }
  }

  /**
   * Returns args[index] parsed as a positive int, or dflt when the argument
   * is absent; a malformed or non-positive argument is reported on stderr
   * and replaced by dflt.
   */
  private static int parseSizeArg(String[] args, int index, int dflt) {
    if (args.length <= index) {
      return dflt;
    }
    try {
      int value = Integer.parseInt(args[index]);
      if (value > 0) {
        return value;
      }
    } catch (NumberFormatException e) {
      // Reported below together with the non-positive case.
    }
    System.err.println("ignoring invalid size argument '" + args[index]
        + "', using " + dflt);
    return dflt;
  }

  /**
   * Runs one benchmark round on comm: rank 0 times seqmult, all ranks time
   * parmult, and rank 0 reports whether both results agree.
   */
  private static void runBenchmark(Communicator comm, int n, int m,
      boolean prnt) throws IOException {
    int myid = comm.rank();
    MatrixMultMPP x = new MatrixMultMPP(comm);

    double[][] A;
    double[][] B;
    if (myid == 0) {
      A = x.matgen1(n, m);
      B = x.matgen(m, n);
    } else {
      // Filled in by parmult with the data sent from rank 0.
      A = new double[n][m];
      B = new double[m][n];
    }
    // The parallel result starts zeroed and is kept apart from the
    // sequential one, so rows parmult fails to deliver cannot be masked.
    double[][] C = x.matgen0(n, n);
    double[][] S = null; // sequential reference, rank 0 only
    long tm;

    if (myid == 0) {
      System.out.println("A = ");
      if (prnt) {
        x.matprint(A);
      }
      System.out.println("");

      System.out.println("B = ");
      if (prnt) {
        x.matprint(B);
      }
      System.out.println("");

      System.out.println("C = ");
      S = x.matgen0(n, n);
      tm = System.currentTimeMillis();
      x.seqmult(S, A, B);
      tm = System.currentTimeMillis() - tm;
      System.out.println(tm + "ms");
    }

    tm = System.currentTimeMillis();
    x.parmult(C, A, B);
    tm = System.currentTimeMillis() - tm;

    if (myid == 0) {
      System.out.println(tm + "ms");
      if (prnt) {
        x.matprint(C);
      }
      System.out.println("");

      System.out.println("D = ");
      double[][] D = x.matgen0(n, n);
      x.seqdiff(D, S, C);
      if (prnt) {
        x.matprint(D);
      }
      System.out.println("");
      System.out.println("D is zero = " + x.matcheck0(D));
    }
  }

}
