/* -*- c -*- */

#include <stdio.h>
#include <stdlib.h>

/* Dense matrix of doubles stored as an array of row pointers: data[i][j]
 * is the element in row i, column j.  newmat() below allocates the header,
 * the row-pointer array, and each row as separate blocks.  MATRIXP is the
 * pointer type that the functions below take and return. */
typedef struct
{
    int rows;          /* number of rows, i.e. number of entries in data */
    int cols;          /* number of columns, i.e. length of each data[i] */
    double **data;     /* rows pointers, each to a row of cols doubles */
} MATRIX, *MATRIXP;

MATRIXP
newmat (int rows, int cols)
{
    MATRIXP mat = (MATRIXP)malloc(sizeof(MATRIX));
    int i;

    mat->rows = rows;
    mat->cols = cols;
    mat->data = (double**)malloc(sizeof(double*) * rows);

    for (i = 0; i < rows; ++i)
	mat->data[i] = (double*)malloc(sizeof(double) * cols);

    return mat;
}

/* Return a newly allocated deep copy of MAT: same shape, same element
 * values.  The caller owns the result. */
MATRIXP
copymat (MATRIXP mat)
{
    MATRIXP dup = newmat(mat->rows, mat->cols);
    int r, c;

    for (r = 0; r < mat->rows; ++r)
    {
	double *src = mat->data[r];
	double *dst = dup->data[r];

	for (c = 0; c < mat->cols; ++c)
	    dst[c] = src[c];
    }

    return dup;
}

/* Extract the ROWS x COLS sub-block of MAT whose top-left corner is at
 * (ROW, COL) into a newly allocated matrix owned by the caller.
 * The block must lie entirely inside MAT; this is not checked. */
MATRIXP
partmat (MATRIXP mat, int row, int col, int rows, int cols)
{
    MATRIXP sub = newmat(rows, cols);
    int r, c;

    for (r = 0; r < rows; ++r)
    {
	double *dst = sub->data[r];

	for (c = 0; c < cols; ++c)
	    dst[c] = mat->data[row + r][col + c];
    }

    return sub;
}

/* Copy all of MAT2 into MAT1, with MAT2's (0,0) landing at MAT1's
 * (ROW, COL).  MAT1 is modified in place and returned.  MAT2 must fit
 * inside MAT1 at that offset; this is not checked. */
MATRIXP
pastemat (MATRIXP mat1, MATRIXP mat2, int row, int col)
{
    int r = 0;

    while (r < mat2->rows)
    {
	double *src = mat2->data[r];
	int c = 0;

	while (c < mat2->cols)
	{
	    mat1->data[row + r][col + c] = src[c];
	    ++c;
	}
	++r;
    }

    return mat1;
}

/* Allocate a ROWS x COLS matrix filled with pseudo-random integers drawn
 * with rand() from the closed range [LOWBOUND, UPBOUND].
 *
 * Reversed bounds are swapped so the range is never empty.  Otherwise
 * "upbound - lowbound + 1" could be zero (modulo by zero is undefined) or
 * negative (values outside the requested range).  The span is computed in
 * long long so that extreme int bounds cannot overflow.
 *
 * Values carry the usual modulo bias of "rand() % n".  The sequence is
 * deterministic unless the caller seeds rand() with srand().
 *
 * Returns the new matrix (owned by the caller), or NULL if newmat()
 * returns NULL. */
MATRIXP
randmat (int rows, int cols, int lowbound, int upbound)
{
    MATRIXP mat = newmat(rows, cols);
    long long span;
    int i, j;

    if (mat == NULL)
	return NULL;

    if (upbound < lowbound)
    {
	int tmp = lowbound;

	lowbound = upbound;
	upbound = tmp;
    }
    span = (long long)upbound - lowbound + 1;   /* always >= 1 here */

    for (i = 0; i < rows; ++i)
	for (j = 0; j < cols; ++j)
	    mat->data[i][j] = (double)(rand() % span + lowbound);

    return mat;
}

/* In-place element-wise sum: MAT1 += MAT2.  MAT2 must have at least
 * MAT1's shape (not checked).  Returns MAT1, so that
 * addmat(copymat(a), b) yields a fresh a + b. */
MATRIXP
addmat (MATRIXP mat1, MATRIXP mat2)
{
    int r = 0;

    while (r < mat1->rows)
    {
	double *dst = mat1->data[r];
	int c;

	for (c = 0; c < mat1->cols; ++c)
	    dst[c] += mat2->data[r][c];
	++r;
    }

    return mat1;
}

/* In-place element-wise difference: MAT1 -= MAT2.  MAT2 must have at
 * least MAT1's shape (not checked).  Returns MAT1, so that
 * submat(copymat(a), b) yields a fresh a - b. */
MATRIXP
submat (MATRIXP mat1, MATRIXP mat2)
{
    int r = 0;

    while (r < mat1->rows)
    {
	double *dst = mat1->data[r];
	int c;

	for (c = 0; c < mat1->cols; ++c)
	    dst[c] -= mat2->data[r][c];
	++r;
    }

    return mat1;
}

/* Naive triple-loop product: store MAT1 * MAT2 into the preallocated MAT.
 *
 * The following preconditions are not checked:
 *   - mat->rows == mat1->rows
 *   - mat->cols == mat2->cols
 *   - mat1->cols == mat2->rows
 * MAT must also be a different matrix from MAT1 and MAT2.  Each output
 * element is zeroed before the inputs are read, so aliasing would
 * corrupt the result.
 *
 * Returns MAT, matching the declared MATRIXP return type. */
MATRIXP
multmattrivial (MATRIXP mat, MATRIXP mat1, MATRIXP mat2)
{
    int i, j, k;

    for (i = 0; i < mat->rows; ++i)
	for (j = 0; j < mat->cols; ++j)
	{
	    mat->data[i][j] = 0;

	    for (k = 0; k < mat1->cols; ++k)
		mat->data[i][j] += mat1->data[i][k] * mat2->data[k][j];
	}

    return mat;
}

/* Release a matrix allocated by newmat(): each row, the row-pointer array,
 * and the header.  NULL is ignored.  File-local helper used by
 * multmatstrassen() to drop its temporaries. */
static void
strassen_release (MATRIXP mat)
{
    int i;

    if (mat == NULL)
	return;
    for (i = 0; i < mat->rows; ++i)
	free(mat->data[i]);
    free(mat->data);
    free(mat);
}

/* Multiply two square matrices of equal size with Strassen's algorithm and
 * return the product as a new matrix owned by the caller.  Square, equal
 * dimensions are assumed and not checked; only mat1->rows is read.
 *
 * Even sizes split into four halfsize x halfsize quadrants and use the
 * seven products p1..p7.  Odd sizes recurse on the leading
 * (size-1) x (size-1) block and add the last row/column contributions
 * directly.
 *
 * Every intermediate matrix is released as soon as it is no longer
 * needed, so only the returned product remains allocated. */
MATRIXP
multmatstrassen (MATRIXP mat1, MATRIXP mat2)
{
    int size =  mat1->rows,
	halfsize = size / 2;
    MATRIXP a11, a12, a21, a22,
	b11, b12, b21, b22,
	t1, t2,
	p1, p2, p3, p4, p5, p6, p7,
	c11, c12, c21, c22,
	mat = newmat(size, size);

    /* Base cases.  Size 0 must stop here: halfsize would stay 0 and the
       recursion below would never terminate. */
    if (size <= 1)
    {
	if (size == 1)
	    mat->data[0][0] = mat1->data[0][0] * mat2->data[0][0];

	return mat;
    }

    if ((size % 2) == 1)
    {
	MATRIXP matp1, matp2, matp;
	int i, j;

	/* Product of the leading (size-1) x (size-1) blocks ... */
	matp1 = partmat(mat1, 0, 0, size - 1, size - 1);
	matp2 = partmat(mat2, 0, 0, size - 1, size - 1);

	matp = multmatstrassen(matp1, matp2);
	pastemat(mat, matp, 0, 0);

	strassen_release(matp1);
	strassen_release(matp2);
	strassen_release(matp);

	/* ... plus the term from the last column of A and last row of B. */
	for (i = 0; i < size - 1; ++i)
	    for (j = 0; j < size - 1; ++j)
		mat->data[i][j] += mat1->data[i][size - 1] *
		                   mat2->data[size - 1][j];
	/* Last column of C, computed in full. */
	for (i = 0; i < size; ++i)
	{
	    mat->data[i][size - 1] = 0;
	    for (j = 0; j < size; ++j)
		mat->data[i][size - 1] += mat1->data[i][j] *
		                          mat2->data[j][size - 1];
	}
	/* Last row of C (except the corner, already done), in full. */
	for (i = 0; i < size - 1; ++i)
	{
	    mat->data[size - 1][i] = 0;
	    for (j = 0; j < size; ++j)
		mat->data[size - 1][i] += mat1->data[size - 1][j] *
		                          mat2->data[j][i];
	}

	return mat;
    }

    a11 = partmat(mat1, 0, 0, halfsize, halfsize);
    a12 = partmat(mat1, 0, halfsize, halfsize, halfsize);
    a21 = partmat(mat1, halfsize, 0, halfsize, halfsize);
    a22 = partmat(mat1, halfsize, halfsize, halfsize, halfsize);

    b11 = partmat(mat2, 0, 0, halfsize, halfsize);
    b12 = partmat(mat2, 0, halfsize, halfsize, halfsize);
    b21 = partmat(mat2, halfsize, 0, halfsize, halfsize);
    b22 = partmat(mat2, halfsize, halfsize, halfsize, halfsize);

    /* p1 = (a11 + a22)(b11 + b22) */
    t1 = addmat(copymat(a11), a22);
    t2 = addmat(copymat(b11), b22);
    p1 = multmatstrassen(t1, t2);
    strassen_release(t1);
    strassen_release(t2);

    /* p2 = (a21 + a22) b11 */
    t1 = addmat(copymat(a21), a22);
    p2 = multmatstrassen(t1, b11);
    strassen_release(t1);

    /* p3 = a11 (b12 - b22) */
    t2 = submat(copymat(b12), b22);
    p3 = multmatstrassen(a11, t2);
    strassen_release(t2);

    /* p4 = a22 (b21 - b11) */
    t2 = submat(copymat(b21), b11);
    p4 = multmatstrassen(a22, t2);
    strassen_release(t2);

    /* p5 = (a11 + a12) b22 */
    t1 = addmat(copymat(a11), a12);
    p5 = multmatstrassen(t1, b22);
    strassen_release(t1);

    /* p6 = (a21 - a11)(b11 + b12) */
    t1 = submat(copymat(a21), a11);
    t2 = addmat(copymat(b11), b12);
    p6 = multmatstrassen(t1, t2);
    strassen_release(t1);
    strassen_release(t2);

    /* p7 = (a12 - a22)(b21 + b22) */
    t1 = submat(copymat(a12), a22);
    t2 = addmat(copymat(b21), b22);
    p7 = multmatstrassen(t1, t2);
    strassen_release(t1);
    strassen_release(t2);

    strassen_release(a11);
    strassen_release(a12);
    strassen_release(a21);
    strassen_release(a22);
    strassen_release(b11);
    strassen_release(b12);
    strassen_release(b21);
    strassen_release(b22);

    c11 = addmat(submat(addmat(copymat(p1), p4), p5), p7);
    c12 = addmat(copymat(p3), p5);
    c21 = addmat(copymat(p2), p4);
    c22 = addmat(submat(addmat(copymat(p1), p3), p2), p6);

    strassen_release(p1);
    strassen_release(p2);
    strassen_release(p3);
    strassen_release(p4);
    strassen_release(p5);
    strassen_release(p6);
    strassen_release(p7);

    pastemat(mat, c11, 0, 0);
    pastemat(mat, c12, 0, halfsize);
    pastemat(mat, c21, halfsize, 0);
    pastemat(mat, c22, halfsize, halfsize);

    strassen_release(c11);
    strassen_release(c12);
    strassen_release(c21);
    strassen_release(c22);

    return mat;
}

/* Print MAT to stdout, one row per line.  Each element is printed with
 * the "% 6.2f " format: a space in place of a '+' sign, width 6, and two
 * decimals. */
void
printmat (MATRIXP mat)
{
    int r = 0;

    while (r < mat->rows)
    {
	int c = 0;

	while (c < mat->cols)
	{
	    printf("% 6.2f ", mat->data[r][c]);
	    ++c;
	}
	printf("\n");
	++r;
    }
}

/* Demo driver: build two 5x5 matrices of random integers in [-2, 2],
 * multiply them naively and with Strassen, and print the inputs and both
 * products so they can be compared by eye. */
int
main (void)
{
    enum { DIM = 5 };
    MATRIXP a, b, naive, fast;

    /* Same call order as before, so rand() produces the same matrices. */
    a = randmat(DIM, DIM, -2, 2);
    b = randmat(DIM, DIM, -2, 2);
    naive = newmat(DIM, DIM);

    printf("A:\n");
    printmat(a);
    printf("\nB:\n");
    printmat(b);

    multmattrivial(naive, a, b);
    fast = multmatstrassen(a, b);

    printf("\nC trivial:\n");
    printmat(naive);
    printf("\nC strassen:\n");
    printmat(fast);

    return 0;
}
