/*****************************************************************
 * mmult.c -- multithreaded matrix multiplication example
 *
 * Author: Mark Hays <hays@math.arizona.edu>
 */

#include "pt.h"

#include <stdio.h>
#include <stdlib.h>

#define NTHREADS 4

/*
 * Shared work area read by the worker threads.  domult() fills in
 * every field below before the threads are started.
 */
int mwork;                      /* rows of a (and of c)              */
int nwork;                      /* cols of a == rows of b            */
int pwork;                      /* cols of b (and of c)              */

double **mata;                  /* left operand,  mwork x nwork      */
double **matb;                  /* right operand, nwork x pwork      */
double **matc;                  /* result,        mwork x pwork      */

int mdone;                      /* next row index to hand out; only  */
                                /* touched under mdone_mutex by the  */
                                /* workers (see nextrow)             */
pthread_mutex_t mdone_mutex;    /* lock for mdone                    */

/******************************************************
 * return the next row number to do or -1 if we're done
 *
 * Hands out each row index 0..mwork-1 exactly once across all
 * worker threads; mdone is read and updated only while holding
 * mdone_mutex.  The counter is advanced only while rows remain, so
 * calls made after the work is exhausted leave mdone at mwork
 * instead of letting it grow without bound.
 */
int nextrow(void)
{
  int res;

  pt_mutex_lock(&mdone_mutex,"mutex lck");     /* lock the mutex      */
  if (mdone < mwork)
    res=mdone++;                               /* claim this row      */
  else
    res=-1;                                    /* nothing left to do  */
  pt_mutex_unlock(&mdone_mutex,"mutex unlck"); /* unlock the mutex    */
  return(res);
}

/******************************************************************
 * main code for each thread -- it computes rows of the product a*b
 *
 * Repeatedly claims a row m via nextrow() and fills in
 * matc[m][0..pwork-1] with the dot products of row m of mata and the
 * columns of matb.  Returns NULL once nextrow() reports no rows left.
 * The thread argument is not used.
 */
pt_addr_t thread_code(pt_arg_t arg)
{
  int m;                                /* row we're working on     */

  (void) arg;                           /* unused; silences warning */

  while ((m=nextrow())!=-1) {           /* m=row to do              */
    /* row pointers are invariant for the whole row, so fetch them
     * once instead of re-reading the pointer tables per element */
    const double *arow=mata[m];
    double *crow=matc[m];
    int p;

    for (p=0; p<pwork; p++) {           /* p=col to do, do each col */
      double sum=0.0;
      int n;

      for (n=0; n<nwork; n++)
        sum += arow[n]*matb[n][p];      /* compute the elt          */
      crow[p]=sum;                      /* save result in c         */
    }
  }
  return(NULL);
}

/***********************************
 * this is the main multiply routine
 *
 * Computes c = a*b by publishing the operands and dimensions in the
 * global work area, resetting the row counter, and then running
 * thread_code on NTHREADS threads via pt_fork.  mdone_mutex is
 * expected to be initialized already (main does so before calling).
 * NOTE(review): callers read c right after this returns, which
 * presumes pt_fork waits for all threads -- confirm in pt.h/pt.c.
 */
void domult(int m,int n,int p,
	    double **a,                    /* a: m x n matrix      */
	    double **b,                    /* b: n x p matrix      */
	    double **c)                    /* c: m x p matrix=a*b  */
{
  /* the operands ... */
  mata=a;
  matb=b;
  matc=c;

  /* ... their dimensions, and a fresh row counter */
  mwork=m;
  nwork=n;
  pwork=p;
  mdone=0;

  pt_fork(NTHREADS,thread_code,NULL,NULL);   /* run code in parallel */
}

/**************************
 * allocate an m x n matrix
 *
 * Returns an array of m row pointers, each pointing to n doubles
 * (all elements zero-initialized).  Release with free() on every row
 * and then on the returned pointer.  Returns NULL if m or n is
 * negative or if any allocation fails; in that case nothing is leaked.
 */
double **newmat(int m,int n)
{
  double **res;
  size_t rows, cols;
  int i;

  if (m < 0 || n < 0)
    return(NULL);

  /* calloc checks count*size for overflow; ask for at least one
   * element so a NULL return always means failure */
  rows = m > 0 ? (size_t) m : 1;
  cols = n > 0 ? (size_t) n : 1;

  res = calloc(rows, sizeof *res);
  if (res == NULL)
    return(NULL);

  for (i=0; i<m; i++) {
    res[i] = calloc(cols, sizeof *res[i]);
    if (res[i] == NULL) {
      while (i-- > 0)                   /* undo the rows already made */
        free(res[i]);
      free(res);
      return(NULL);
    }
  }
  return(res);
}

/* release an m-row matrix obtained from newmat(); NULL is ignored */
static void freemat(double **mat,int m)
{
  int i;

  if (mat == NULL)
    return;
  for (i=0; i<m; i++)
    free(mat[i]);
  free(mat);
}

/**************
 * main program
 *
 * Multiplies two fixed 2x2 matrices with domult() and prints the
 * product.  Returns EXIT_FAILURE if a matrix cannot be allocated.
 */
int main(int argc,char *argv[])
{
  double **a=newmat(2,2),**b=newmat(2,2),**c=newmat(2,2);

  (void) argc;                              /* command line unused  */
  (void) argv;

  if (a == NULL || b == NULL || c == NULL) {
    fprintf(stderr,"mmult: out of memory\n");
    freemat(a,2); freemat(b,2); freemat(c,2);
    return(EXIT_FAILURE);
  }

  a[0][0]=1.0; a[0][1]=2.0;                 /* initialize a         */
  a[1][0]=2.0; a[1][1]=1.0;

  b[0][0]=1.0; b[0][1]=2.0;                 /* initialize b         */
  b[1][0]=3.0; b[1][1]=4.0;

  pt_mutex_init(&mdone_mutex,"init mutex"); /* initialize the mutex */

  domult(2,2,2,a,b,c);                      /* do the mmult: ans is */
  printf("%8.4f%8.4f\n",c[0][0],c[0][1]);   /*   7.0000 10.0000     */
  printf("%8.4f%8.4f\n",c[1][0],c[1][1]);   /*   5.0000  8.0000     */

  pt_mutex_destroy(&mdone_mutex,"del mutex");

  freemat(a,2); freemat(b,2); freemat(c,2); /* release the matrices */
  return(0);
}

/* EOF mmult.c */

