001    /*
002     * $Id: MatrixMultMPP.java,v 1.5 2002/08/03 17:17:42 kredel Exp $
003     */
004    
005    //package edu.unima.ky.parallel.mpijava;
006    
007    import java.io.IOException;
008    import mpp.Communicator;
009    import mpp.BlockCommunicator;
010    import mpp.StreamCommunicator;
011    
012    
013    /**
014     * Matrix Multiplication.
015     * sequential and parallel using MPP.
016     * @author Akitoshi Yoshida
017     * @author Heinz Kredel.
018     */
019    public class MatrixMultMPP {
020    
021      protected Communicator mpi_comm_world;
022    
023      public MatrixMultMPP(Communicator comm) {
024          mpi_comm_world = comm;
025      }
026    
027    /**
028     * @param C two-dimensional double array.
029     * @param A two-dimensional double array.
030     * @param B two-dimensional double array.
031     */
032    public void seqmult(double[][] C, double[][] A, double[][] B) {
033          for (int i=0; i < A.length; i++) {
034              for (int j=0; j < B[0].length; j++) {
035                  double c = 0.0;
036                  for (int k=0; k < B.length; k++) {
037                      c += A[i][k] * B[k][j];
038                  }
039                  C[i][j] = c;
040              }
041          }
042      }
043    
044    /**
045     * @param C two-dimensional double array.
046     * @param A two-dimensional double array.
047     * @param B two-dimensional double array.
048     * @throws IOException.
049     */
050    public void parmult(double[][] C, double[][] A, double[][] B) 
051             throws IOException {
052    
053          boolean prnt = false;
054    
055          // initialize MPI Communicators
056          int myid     = mpi_comm_world.rank() ;
057          int numprocs = mpi_comm_world.size() ;
058          if (true||prnt) {
059             System.out.println("MatrixMultMPP on " + myid 
060                 + " of " + numprocs + " ") ;
061          }
062          int counts = A.length;
063          int[] assign = new int[counts];
064          for (int k=0; k < counts; k++) {
065              assign[k] = k % numprocs;
066          }
067    
068          /* transfer data */
069          //int msgtag = 4711;
070          if ( myid == 0 ) {
071             if (true||prnt) {
072                System.out.print("assign = ");
073                vecprint(assign);
074                //System.out.print("sending A ");
075             }
076             for (int k=0; k < counts; k++ ) {
077                 if ( myid != assign[k] ) {
078                    mpi_comm_world.send(A[k], 0, A[k].length, assign[k]);
079                 }
080             }
081          } else {
082             for (int y = 0; y < counts; y++ ) {
083                 if ( myid == assign[y] ) {
084                    mpi_comm_world.recv(A[y], 0, A[y].length, 0);
085                 }
086             }
087          }
088          if ( myid == 0 ) {
089             if (prnt) {
090                System.out.print("sending B ");
091             }
092          }
093          for (int x=0; x < B.length; x++ ){
094             mpi_comm_world.bcast(B[x], 0, B[x].length, 0);
095          }
096    
097          for (int i=0; i < counts; i++) {
098              if ( myid == assign[i] ) {
099                 for (int j=0; j < B[0].length; j++) {
100                     double c = 0.0;
101                     for (int k=0; k < B.length; k++) {
102                         c += A[i][k] * B[k][j];
103                     }
104                 C[i][j] = c;
105                 }
106              }
107          }
108    
109          if ( myid == 0 ) {
110             if (prnt) {
111                System.out.print("receiving C ");
112             }
113             for (int k=0; k < counts; k++ ) {
114                 if ( myid != assign[k] ) {
115                    mpi_comm_world.recv(C[k], 0, C[k].length, assign[k]);
116                 }
117             }
118          } else {
119             for (int y = 0; y < counts; y++ ) {
120                 if ( myid == assign[y] ) {
121                    mpi_comm_world.send(C[y], 0, C[y].length, 0);
122                 }
123             }
124          }
125    
126    
127      }
128    
129    
130    /**
131     * @param C two-dimensional double array.
132     * @param A two-dimensional double array.
133     * @param B two-dimensional double array.
134     */
135    public void seqdiff(double[][] C, double[][] A, double[][] B) {
136          for (int i=0; i < C.length; i++) {
137              for (int j=0; j < C[0].length; j++) {
138                  C[i][j] = A[i][j] - B[i][j];
139              }
140          }
141      }
142    
143    /**
144     * @param C two-dimensional double array.
145     * @param A two-dimensional double array.
146     * @param B two-dimensional double array.
147     * @param i row of A.
148     * @param j column of B.
149     */
150    public void dotmult(double[][] C, double[][] A, double[][] B, int i, int j) {
151          double c = 0.0;
152          for (int k=0; k < B.length; k++) {
153              c += A[i][k] * B[k][j];
154          }
155          C[i][j] = c;
156      }
157    
158    /**
159     * @param C two-dimensional double array.
160     * @param A two-dimensional double array.
161     * @param B two-dimensional double array.
162     */
163    public void seq2mult(double[][] C, double[][] A, double[][] B) {
164          for (int i=0; i < A.length; i++) {
165              for (int j=0; j < B[0].length; j++) {
166                  dotmult(C,A,B,i,j);
167              }
168          }
169      }
170    
171    /**
172     * @param n rows of result.
173     * @param m columns of result.
174     * @return A two-dimensional double array.
175     */
176    public double[][] matgen(int n, int m) {
177          double[][] A = new double[n][m];
178          for (int i=0; i < n; i++) {
179              for (int j=0; j < m; j++) {
180                  A[i][j] = Math.random();
181              }
182          }
183          return A;
184      }
185    
186    /**
187     * @param n rows of result.
188     * @param m columns of result.
189     * @return A two-dimensional double array.
190     */
191    public double[][] matgen0(int n, int m) {
192          double[][] A = new double[n][m];
193          for (int i=0; i < n; i++) {
194              for (int j=0; j < m; j++) {
195                  A[i][j] = 0.0;
196              }
197          }
198          return A;
199      }
200    
201    /**
202     * @param n rows of result.
203     * @param m columns of result.
204     * @return A two-dimensional double array. 
205     */
206    public double[][] matgen1(int n, int m) {
207          double[][] A = new double[n][m];
208          for (int i=0; i < n; i++) {
209              for (int j=0; j < m; j++) {
210                  if (i == j) A[i][j] = 1.0; else A[i][j]= 0.0;
211              }
212          }
213          return A;
214      }
215       
216    /**
217     * @param A two-dimensional double array.
218     */
219    public void matprint(double[][] A) {
220          for (int i=0; i < A.length; i++) {
221              for (int j=0; j < A[0].length; j++) {
222                  System.out.print(A[i][j] + " ");
223              }
224              System.out.println();
225          }
226      }
227    
228    /**
229     * @param V vector to print.
230     */
231    public void vecprint(int[] V) {
232          for (int i=0; i < V.length; i++) {
233              System.out.print(V[i] + " ");
234          }
235          System.out.println();
236      }
237    
238    /**
239     * @param A two dimensional double array.
240     * @return true if A is approximately zero.
241     */
242    public boolean matcheck0(double[][] A) {
243          double eps = Double.MIN_VALUE*1000.0;
244          for (int i=0; i < A.length; i++) {
245              for (int j=0; j < A[0].length; j++) {
246                  if ( Math.abs(A[i][j]) > eps) return false;
247              }
248          }
249          return true;
250      }
251    
252    /**
253     * @param args
254     * @throws IOException
255     */
256    public static void main(String[] args) throws IOException {
257    
258            int n = 300, m = 300;
259            boolean prnt = false;
260    
261            try { n = Integer.parseInt(args[0]); }
262            catch (Exception e) { }
263            try { m = Integer.parseInt(args[1]); }
264            catch (Exception e) { }
265            try { prnt = Boolean.valueOf(args[2]).booleanValue(); }
266            catch (Exception e) { }
267     
268            //prnt=true;
269    
270            //MPI.Init(args) ;
271            Communicator comm = new BlockCommunicator();
272            //Communicator comm = new StreamCommunicator();
273            int myid     = comm.rank() ;
274    
275    
276            MatrixMultMPP x = new MatrixMultMPP(comm);
277    
278            double[][] A;
279            double[][] B;
280            double[][] C;
281            double[][] D;
282            long tm =0; 
283    
284         if ( myid == 0 ) {
285            A = x.matgen1(n,m);
286            B = x.matgen(m,n);
287            C = x.matgen0(n,n);
288            D = x.matgen0(n,n);
289         } else {
290            A = new double[n][m];
291            B = new double[m][n];
292            C = new double[n][n];
293            D = new double[n][n];
294         }
295    
296         if ( myid == 0 ) {
297            System.out.println("A = ");
298            if (prnt) x.matprint(A);
299            System.out.println("");
300    
301            System.out.println("B = ");
302            if (prnt) x.matprint(B);
303            System.out.println("");
304    
305            System.out.println("C = ");
306            tm = System.currentTimeMillis();
307            x.seqmult(C,A,B);
308            tm = System.currentTimeMillis() - tm; 
309            System.out.println(tm + "ms");
310         }
311    
312            tm = System.currentTimeMillis();
313            x.parmult(C,A,B);
314            tm = System.currentTimeMillis() - tm; 
315    
316         if ( myid == 0 ) {
317            System.out.println(tm + "ms");
318            if (prnt) x.matprint(C);
319            System.out.println("");
320    
321            System.out.println("D = ");
322            x.seqdiff(D,B,C);
323            if (prnt) x.matprint(D);
324            System.out.println("");
325            System.out.println("D is zero = " + x.matcheck0(D) );
326            }
327    
328        // MPI.Finalize();
329         comm.close();
330    
331        }
332    
333    }