/*
 * matrices.c
 */

#include <stdio.h>
#include "rtl.h"
#include "tree.h"
#include "predicate.h"
#include "balsched.h"
#include "matrices.h"

/* Evaluate to 1 when NUM1 / NUM2 has no fractional part, i.e. NUM1 is
   a whole multiple of NUM2, and to 0 otherwise.  Each argument is
   parenthesized so that expressions such as `a + b' divide as a unit.
   Both arguments are expanded twice, so they must be free of side
   effects, and the quotient is converted to int for the comparison,
   so it has to be representable as an int.  */
#define MULTIPLE(num1,num2) \
  ((((num1) / (num2)) == ((int) ((num1) / (num2)))) ? 1 : 0)

row
row_create (cols)
     int cols;
{
  row r;

  r = (row) malloc (sizeof (struct row_def));
  r->num_cols = cols;
  r->elts = (float *) malloc (cols * sizeof (float));
  bzero (r->elts, cols * sizeof (float));
  return r;
}

/* Release the storage owned by R: its element array and then the row
   structure itself.  */
void
row_delete (r)
     row r;
{
  /* The element array is only reachable through R, so it goes first.  */
  free (r->elts);
  free (r);
}

matrix
matrix_create (rows, cols)
     int rows, cols;
{
  matrix m;
  row r;
  int i;

  m = (matrix) malloc (sizeof (struct matrix_def));
  m->num_rows = rows;
  m->num_cols = cols;
  m->rows = (row *) malloc (rows * sizeof (row));
  for (i = 0; i < rows; i++)
    {
      m->rows[i] = row_create (cols);
    }
  return m;
}

/* Release everything owned by M: each of its first num_rows rows (via
   row_delete), the array of row handles, and the matrix structure
   itself.  A null M is accepted and ignored.  */
void
matrix_delete (m)
     matrix m;
{
  int i;

  if (m == NULL)
    return;

  for (i = 0; i < m->num_rows; i++)
    row_delete (m->rows[i]);

  /* The handle array is a separate allocation made by matrix_create;
     without this free it would be leaked.  */
  free (m->rows);
  free (m);
}

/* Store VALUE at position (ROW, COL) of M.  Returns 1 on success and
   0, leaving M untouched, when either index lies outside the matrix.  */
int
matrix_set_elt (m, row, col, value)
     matrix m;
     int row, col;
     float value;
{
  if ((row >= 0) && (row < m->num_rows)
      && (col >= 0) && (col < m->num_cols))
    {
      m->rows[row]->elts[col] = value;
      return 1;
    }

  /* Index out of range.  */
  return 0;
}

/* Fetch the element at position (ROW, COL) of M into *VALUE.  Returns
   1 on success; returns 0 without writing *VALUE when either index
   lies outside the matrix.  */
int
matrix_get_elt (m, row, col, value)
     matrix m;
     int row, col;
     float *value;
{
  if ((row >= 0) && (row < m->num_rows)
      && (col >= 0) && (col < m->num_cols))
    {
      *value = m->rows[row]->elts[col];
      return 1;
    }

  /* Index out of range.  */
  return 0;
}

/* Dump M to stderr, one line per row, framed by dashed rules.  Nothing
   is printed for a matrix with no rows or no columns, or when
   jlo_cache_opt_verbose is zero.  */
void
matrix_print (m)
     matrix m;
{
  int r, c;

  if ((m->num_rows == 0) || (m->num_cols == 0))
    return;
  if (!jlo_cache_opt_verbose)
    return;

  fprintf (stderr, "----------------------\n");
  r = 0;
  while (r < m->num_rows)
    {
      for (c = 0; c < m->num_cols; c++)
        fprintf (stderr, " %7.2f", m->rows[r]->elts[c]);
      fprintf (stderr, "\n");
      r++;
    }
  fprintf (stderr, "----------------------\n");
}

/* Dump the elements of R to stderr on a single line framed by dashed
   rules.  Nothing is printed for a row with no columns, or when
   jlo_cache_opt_verbose is zero.  */
void
row_print (r)
     row r;
{
  int c;

  if (r->num_cols == 0)
    return;
  if (!jlo_cache_opt_verbose)
    return;

  fprintf (stderr, "-------------\n");
  c = 0;
  while (c < r->num_cols)
    {
      fprintf (stderr, " %7.2f", r->elts[c]);
      c++;
    }
  fprintf (stderr, "\n");
  fprintf (stderr, "-------------\n");
}


/* Reduce M in place with upper_triangular, then return the matrix
   produced by matrix_solve for the reduced M.  */
matrix
find_span (m)
     matrix m;
{
  upper_triangular (m);

  /* What matrix_solve hands back is really the solution, but reading
     off the coefficients of each variable gives the span.  */
  return matrix_solve (m);
}

/* Reduce M in place towards upper-triangular form.

   For every column COL_NUM, row COL_NUM supplies the pivot (the
   diagonal element) and every row below it is a candidate for
   elimination.  A zero pivot is handled by exchanging row COL_NUM with
   a lower row; rows are exchanged by swapping their handles in
   m->rows, not by copying elements.  Elimination scales both rows by
   factors derived from lcm of the pivot and the element being cleared
   and calls row_eliminate.  Progress is traced to stderr when
   jlo_cache_opt_verbose is set.  */
void
upper_triangular (m)
     matrix m;
{
  int col_num, row_num;
  float l;
  float elt1, elt2;
  float pivot1, pivot2;
  int i;
  row temp_row;
  int leftmost;
  /* NOTE(review): result is a float, but it only ever holds the int
     index returned by row_first_nonzero.  */
  float result;
  
  for (col_num = 0; col_num < m->num_cols; col_num++)
    {
      /* Only rows strictly below the diagonal row are visited, so the
	 body never runs once col_num reaches the last row.  */
      for (row_num = col_num + 1; row_num < m->num_rows; row_num++)
	{
	  /* The pivot is re-read on every pass because earlier passes
	     may have exchanged rows.  */
	  if ((elt1 = m->rows[col_num]->elts[col_num]) == 0.0)
	    {
	      /* Zero pivot, so we swap rows */
	      /* Empty-bodied loop: leaves i at the first row, starting
		 from the pivot row, with a nonzero entry in this
		 column, or at num_rows if there is none.  */
	      for (i = col_num; 
		   (i < m->num_rows) && !(m->rows[i]->elts[col_num]);
		   i++);
	      
	      if (i == m->num_rows)
		{
		  /* All rows below the pivot have a 0 in this column, 
		     so we can go to the next column. */
		  /* But if any rows below this one have a nonzero
		     element further left than the first nonzero element
		     of this row, then swap as well */
		  leftmost = row_first_nonzero (m->rows[col_num]);
		  for (i = col_num + 1;
		       i < m->num_rows;
		       i++)
		    {
		      result = row_first_nonzero (m->rows[i]);
		      if (result < leftmost)
			{
			  temp_row = m->rows[col_num];
			  m->rows[col_num] = m->rows[i];
			  m->rows[i] = temp_row;
			  /* swap */
			  break;
			}
		    }
		  /* Leaves the row loop, so processing resumes with the
		     next column.  */
		  break;  /* or continue? */
		} else {
		  /* exchange rows */
		  temp_row = m->rows[col_num];
		  m->rows[col_num] = m->rows[i];
		  m->rows[i] = temp_row;
		  /* Get new pivot */
		  elt1 = m->rows[col_num]->elts[col_num];
		}
	    }
	  /* row_ok returns 1 only when elements 0..col_num of the row
	     are all zero, so elimination is attempted only for rows
	     that still have a nonzero entry in that range.  */
	  if (!row_ok (m->rows[row_num], col_num))
	    {
	      elt2 = m->rows[row_num]->elts[col_num]; 
	      if (!elt1 || !elt2)
		{
		  if (jlo_cache_opt_verbose)
		    fprintf (stderr, "Zero element in pivot position\n");
		  continue;
		}
	      /* Scale factors chosen so that pivot1 * elt1 and
		 pivot2 * elt2 both equal l; row_eliminate subtracts
		 the scaled pivot row from the scaled target row, which
		 cancels the entry in column col_num.  */
	      l = lcm (elt1, elt2);
	      pivot1 = l / elt1;
	      pivot2 = l / elt2;
	      /* Use row "col_num" to eliminate row "row_num" */
	      row_eliminate (m, col_num, pivot1, row_num, pivot2);
	      if (jlo_cache_opt_verbose)
		fprintf (stderr, "row_num = %d, col_num = %d\n", 
			 row_num, col_num);
	      /* matrix_print checks jlo_cache_opt_verbose itself.  */
	      matrix_print (m);
	    }
	}
    }
}

void
row_eliminate (m, base_row, pivot1, elim_row, pivot2)
     matrix m;
     int base_row, elim_row;
     float pivot1, pivot2;
{
  int i;

  for (i = 0; i < m->num_cols; i++)
    {
      m->rows[elim_row]->elts[i] = (m->rows[elim_row]->elts[i] * pivot2) -
	(m->rows[base_row]->elts[i] * pivot1);
    }
}

/* Solve the homogeneous system described by M, which is expected to
   have been reduced by upper_triangular already, and return a newly
   created matrix describing the solution.

   A table of bindings is built with one row per variable (column of
   M).  In each binding row, column 0 is the constant term and column
   v+1 is the coefficient of variable v.  A free variable is bound to
   itself (coefficient 1 in its own column); every other variable is
   expressed in terms of the variables to its right by back
   substitution.  The table is then transposed and its all-zero rows
   are dropped by span_compress.

   M itself is only read.  The intermediate matrices are released
   here; the returned matrix is a fresh one from span_compress.
   Aborts if a variable that was not marked free has no nonzero entry
   at or above the row being examined.  */
matrix
matrix_solve (m)
     matrix m;
{
  matrix solution;
  matrix bindings;
  matrix span;
  int varnum, rownum;
  int i;
  float pivot;
  row temp;

  bindings = matrix_create (m->num_cols, m->num_cols+1);

  /* Find the free variables, if any.  ROWNUM only advances when a
     pivot is found, so it tracks the next row that may hold one.  */
  for (varnum = 0, rownum = 0; 
       (varnum < m->num_cols) && (rownum < m->num_rows); varnum++)
    {
      if (m->rows[rownum]->elts[varnum] != 0.0)
        {
          if (all_zeroes_below (m, rownum, varnum))
            {
              /* Found a pivot */
              rownum++;
            }
          else
            {
              fprintf (stderr, "WARNING: Matrix not upper triangular\n");
              matrix_print (m);
            }
        }
      else
        {
          /* No pivot here, so mark variable as free and move to 
             the next column */
          bindings->rows[varnum]->elts[varnum+1] = 1.0;
          if (jlo_cache_opt_verbose2)
            fprintf (stderr, "Free variable in column %d\n", varnum);
        }
    }

  /* If more columns than rows, then the last columns are free too */
  for (; varnum < m->num_cols; varnum++)
    {
      bindings->rows[varnum]->elts[varnum+1] = 1.0;
      if (jlo_cache_opt_verbose2)
        fprintf (stderr, "Free variable in column %d\n", varnum);
    }

  /* Back substitution, from the rightmost variable to the leftmost.  */
  rownum = m->num_rows - 1;
  for (varnum = m->num_cols - 1; (varnum >= 0) && (rownum >= 0); varnum--)
    {
      /* A free variable already has its binding from the passes
         above.  For a matrix such as (0,1) (0,0) the last row is
         skipped here, since upper_triangular is assumed to have
         produced echelon form.  */
      if (bindings->rows[varnum]->elts[varnum+1] != 0.0)
        continue;

      while ((pivot = m->rows[rownum]->elts[varnum]) == 0.0)
        { 
          /* this is a zero pivot row, so move up a row to find
             an ok pivot */
          rownum--;
          /* if all entries are 0 in this column, then this variable
             should be free */
          if (rownum < 0)
            {
              fprintf (stderr, 
                       "ERROR: column is all 0s, but it's not free\n");
              abort ();
            }
        }

      /* Pivot exists, so do back substitution.  Since we're trying to
         find the nullspace, a pivot in the rightmost column has
         nothing to its right and its variable stays bound to 0; the
         guard also avoids dividing an all-zero binding.  */
      if (varnum + 1 < m->num_cols)
        {
          /* One scratch row serves every term: row_copy overwrites
             all of it each time round.  */
          temp = row_create (m->num_cols+1);
          for (i = varnum + 1; i < m->num_cols; i++)
            { 
              /* For each column to the right of the pivot's column, 
                 take that variable's binding, scale it by this row's
                 coefficient for it, and add it to this variable's
                 binding.  The sum is kept in the binding row of
                 VARNUM -- bindings are indexed by variable, not by
                 matrix row, and the two differ whenever a free
                 variable lies to the left of this pivot.  */
              row_copy (temp, bindings->rows[i]);
              row_scalar_mult (temp, m->rows[rownum]->elts[i]);
              row_add (bindings->rows[varnum], temp);
            }
          row_delete (temp);

          /* pivot * x + sum = 0, hence x = sum / -pivot.  Divide once,
             after the whole sum is formed; dividing inside the loop
             would scale the earlier terms repeatedly.  */
          row_scalar_div (bindings->rows[varnum], -pivot);
        }
      rownum--;
    }

  solution = matrix_transpose (bindings);
  matrix_delete (bindings);
  span = span_compress (solution);
  matrix_delete (solution);
  return span;
}

/* Return a newly created matrix that is the transpose of M: it has
   M's column count as its row count and vice versa, and its element
   (j, i) equals M's element (i, j).  M is only read.  */
matrix
matrix_transpose (m)
     matrix m;
{
  matrix flipped = matrix_create (m->num_cols, m->num_rows);
  int src_row, src_col;

  src_row = 0;
  while (src_row < m->num_rows)
    {
      for (src_col = 0; src_col < m->num_cols; src_col++)
        flipped->rows[src_col]->elts[src_row]
          = m->rows[src_row]->elts[src_col];
      src_row++;
    }

  return flipped;
}

void
matrix_clear_last_row (m)
     matrix m;
{
  int i;
  for (i = 0; i < m->num_cols; i++)
    m->rows[m->num_rows-1]->elts[i] = 0.0;
}

/* Remove from M every row that has more than one nonzero element.

   Such rows are freed with row_delete.  If any row was removed, the
   surviving rows are moved to the front of m->rows in their original
   order, the vacated slots are set to NULL, m->num_rows is reduced to
   the number of survivors and the result is passed to matrix_print.
   If nothing was removed, M is left exactly as it was.  */
void 
matrix_remove_complex_rows (m)
     matrix m;
{
  int src, col, kept;
  int seen_nonzero;
  int any_removed = 0;

  /* First pass: free each row as soon as its second nonzero element
     is met, leaving NULL in its slot.  */
  for (src = 0; src < m->num_rows; src++)
    {
      seen_nonzero = 0;
      for (col = 0; col < m->num_cols; col++)
        {
          if (m->rows[src]->elts[col] == 0.0)
            continue;
          if (!seen_nonzero)
            {
              seen_nonzero = 1;
              continue;
            }
          /* remove this row */
          row_delete (m->rows[src]);
          m->rows[src] = NULL;
          any_removed = 1;
          break;
        }
    }

  if (!any_removed)
    return;

  if (jlo_cache_opt_verbose2)
    fprintf (stderr, "Compressing to:\n");

  /* Second pass: slide the survivors down over the NULL slots.  */
  kept = 0;
  for (src = 0; src < m->num_rows; src++)
    {
      if (m->rows[src] == NULL)
        continue;
      if (kept != src)
        {
          m->rows[kept] = m->rows[src];
          m->rows[src] = NULL;
        }
      kept++;
    }

  m->num_rows = kept;
  matrix_print (m);
}

/* Return a newly created matrix with the same dimensions and element
   values as M.  M is only read.  */
matrix
matrix_copy (m)
     matrix m;
{
  matrix duplicate = matrix_create (m->num_rows, m->num_cols);
  int r, c;

  r = 0;
  while (r < m->num_rows)
    {
      for (c = 0; c < m->num_cols; c++)
        duplicate->rows[r]->elts[c] = m->rows[r]->elts[c];
      r++;
    }

  return duplicate;
}

/* Return 1 when every row of M strictly below ROW_NUM has a zero in
   column COL_NUM (trivially so when there are no such rows), and 0 as
   soon as a nonzero entry is found.  */
int 
all_zeroes_below (m, row_num, col_num)
     matrix m;
     int row_num, col_num;
{
  int below = row_num + 1;

  while (below < m->num_rows)
    {
      if (m->rows[below]->elts[col_num] != 0.0)
        return 0;
      below++;
    }

  return 1;
}
 
/* Return 1 when elements 0 through COL_NUM (inclusive) of R are all
   zero, and 0 as soon as a nonzero element is found in that range.  */
int
row_ok (r, col_num)
     row r;
     int col_num;
{
  int col = 0;

  while (col <= col_num)
    {
      if (r->elts[col] != 0.0)
        return 0;
      col++;
    }

  return 1;
}

/* Return a non-negative common multiple of NUM1 and NUM2.

   The signs of the arguments are ignored.  When both magnitudes are
   whole numbers small enough to be held exactly in a float, the
   result is their least common multiple, computed through the
   greatest common divisor.  If either argument is zero the result is
   zero.  Otherwise (a fractional, very large or non-finite argument)
   the product of the magnitudes is returned: no least common multiple
   exists for such values, and a search for one by trial division
   would never terminate, but the product still divides evenly by
   each argument's magnitude in the sense that product / |NUM1| is
   |NUM2| and vice versa.  */
float
lcm (num1, num2)
     float num1, num2;
{
  float a, b;
  long ia, ib, g, rem, t;

  a = (num1 < 0.0) ? -num1 : num1;
  b = (num2 < 0.0) ? -num2 : num2;

  if (a == 0.0 || b == 0.0)
    return 0.0;

  /* 2^24 is where a float stops representing every integer exactly.
     Written as a negated `<' so that NaN takes this branch too.  */
  if (!(a < 16777216.0) || !(b < 16777216.0))
    return a * b;

  ia = (long) a;
  ib = (long) b;
  if ((float) ia != a || (float) ib != b)
    return a * b;		/* fractional input */

  /* Euclid's algorithm for the greatest common divisor.  */
  g = ia;
  rem = ib;
  while (rem != 0)
    {
      t = g % rem;
      g = rem;
      rem = t;
    }

  /* Divide before multiplying to keep the intermediate small.  */
  return (float) (ia / g) * b;
}

/* Copy every element of SRC into DEST.  Both rows must have the same
   number of columns; if they do not, a message is written to stderr,
   DEST is left untouched and 0 is returned.  Returns 1 on success.  */
int 
row_copy (dest, src)
     row dest;
     row src;
{
  int col;

  if (dest->num_cols != src->num_cols)
    {
      fprintf (stderr, "Number of columns in src and dest are different.\n");
      return 0;
    }

  col = 0;
  while (col < dest->num_cols)
    {
      dest->elts[col] = src->elts[col];
      col++;
    }
  return 1;
}

/* Multiply every element of DEST by SCALAR in place.  Always returns
   1.  */
int 
row_scalar_mult (dest, scalar)
     row dest;
     float scalar;
{
  int col = 0;

  while (col < dest->num_cols)
    {
      dest->elts[col] = dest->elts[col] * scalar;
      col++;
    }

  return 1;
}

/* Divide every element of DEST by SCALAR in place.  Always returns 1.
   When INTS is defined, a warning is written to stderr for each
   element whose MOD check against SCALAR is nonzero after the
   division.  */
int 
row_scalar_div (dest, scalar)
     row dest;
     float scalar;
{
  int col = 0;

  while (col < dest->num_cols)
    {
      dest->elts[col] = dest->elts[col] / scalar;
#ifdef INTS
      if ((dest->elts[col] MOD scalar) != 0.0)
	fprintf (stderr, "WARNING in row_scalar_div: division results are not integers\n");
#endif
      col++;
    }

  return 1;
}

/* Add SRC to DEST element by element, storing the sums in DEST.  Both
   rows must have the same number of columns; if they do not, a
   message is written to stderr, DEST is left untouched and 0 is
   returned.  Returns 1 on success.  */
int 
row_add (dest, src)
     row dest;
     row src;
{
  int col;

  if (dest->num_cols != src->num_cols)
    {
      fprintf (stderr, "Number of columns in src and dest are different.\n");
      return 0;
    }

  col = 0;
  while (col < dest->num_cols)
    {
      dest->elts[col] = dest->elts[col] + src->elts[col];
      col++;
    }
  return 1;
}

/* Return 1 when every element of R is zero (including the case of a
   row with no columns), and 0 as soon as a nonzero element is
   found.  */
int 
row_equal_zero (r)
     row r;
{
  int col = 0;

  while (col < r->num_cols)
    {
      if (r->elts[col] != 0.0)
        return 0;
      col++;
    }

  return 1;
}

/* Return the index of the first nonzero element of R, or the row's
   column count when every element is zero.  */
int
row_first_nonzero (r)
     row r;
{
  int col = 0;

  /* Skip over the leading zeros; stopping at num_cols yields the
     "none found" value directly.  */
  while ((col < r->num_cols) && !(r->elts[col] != 0.0))
    col++;

  return col;
}

/* Return a newly created matrix holding copies of the rows of M that
   are not entirely zero, in their original order.  M is only read.  */
matrix
span_compress (m)
     matrix m;
{
  matrix packed;
  int src;
  int live = 0;
  int dst = 0;

  /* Count the rows worth keeping so the result can be sized.  */
  for (src = 0; src < m->num_rows; src++)
    if (!row_equal_zero (m->rows[src]))
      live++;

  packed = matrix_create (live, m->num_cols);

  /* Copy rows */
  for (src = 0; src < m->num_rows; src++)
    {
      if (row_equal_zero (m->rows[src]))
        continue;
      row_copy (packed->rows[dst], m->rows[src]);
      dst++;
    }

  return packed;
}

/* Return 1 when M has no rows, no columns, or only zero elements;
   return 0 as soon as a nonzero element is found.  */
int 
matrix_is_zero (m)
     matrix m;
{
  int r, c;

  if ((m->num_rows == 0) || (m->num_cols == 0))
    return 1;

  r = 0;
  while (r < m->num_rows)
    {
      for (c = 0; c < m->num_cols; c++)
        {
          if (m->rows[r]->elts[c] != 0.0)
            return 0;
        }
      r++;
    }

  return 1;
}

