/* -*- c -*- */
/*
 * Dense matrices of doubles stored as an array of row pointers, with two
 * multiplication routines (schoolbook and Strassen) and a small demo driver.
 *
 * Ownership: every function that returns a matrix it created (newmat,
 * copymat, partmat, randmat, multmatstrassen) hands ownership to the caller.
 * addmat, submat, pastemat and multmattrivial modify and return one of their
 * arguments and allocate nothing.
 *
 * Define MATRIX_NO_MAIN to compile the routines without the demo main(),
 * e.g. when pulling this file into a test harness.
 */
#include <stdio.h>
#include <stdlib.h>

typedef struct {
    int rows;       /* number of rows */
    int cols;       /* number of columns */
    double **data;  /* data[i] points to row i, which holds `cols` doubles */
} MATRIX, *MATRIXP;

/* malloc wrapper: terminates the program instead of returning NULL on failure. */
static void *xmalloc(size_t size)
{
    void *ptr = malloc(size);

    /* malloc(0) may legitimately return NULL, so only a non-zero request
       that fails counts as out of memory. */
    if (ptr == NULL && size != 0) {
        fprintf(stderr, "matrix: out of memory\n");
        exit(EXIT_FAILURE);
    }
    return ptr;
}

/* Releases a matrix created by newmat (or anything built on it). NULL is a no-op. */
static void freemat(MATRIXP mat)
{
    int i;

    if (mat == NULL) {
        return;
    }
    for (i = 0; i < mat->rows; ++i) {
        free(mat->data[i]);
    }
    free(mat->data);
    free(mat);
}

/*
 * Allocates a rows x cols matrix. The elements are NOT initialized.
 * Non-positive dimensions are stored as given but allocate no element storage.
 */
MATRIXP newmat(int rows, int cols)
{
    size_t nrows = rows > 0 ? (size_t)rows : 0;
    size_t ncols = cols > 0 ? (size_t)cols : 0;
    size_t i;
    MATRIXP mat = xmalloc(sizeof *mat);

    mat->rows = rows;
    mat->cols = cols;
    mat->data = xmalloc(nrows * sizeof *mat->data);
    for (i = 0; i < nrows; ++i) {
        mat->data[i] = xmalloc(ncols * sizeof **mat->data);
    }
    return mat;
}

/* Returns a newly allocated element-wise copy of mat. */
MATRIXP copymat(MATRIXP mat)
{
    MATRIXP matnew = newmat(mat->rows, mat->cols);
    int i, j;

    for (i = 0; i < mat->rows; ++i) {
        for (j = 0; j < mat->cols; ++j) {
            matnew->data[i][j] = mat->data[i][j];
        }
    }
    return matnew;
}

/*
 * Returns a newly allocated rows x cols copy of the sub-block of mat whose
 * top-left element is mat[row][col]. The block must lie inside mat; no
 * bounds checking is performed.
 */
MATRIXP partmat(MATRIXP mat, int row, int col, int rows, int cols)
{
    MATRIXP matnew = newmat(rows, cols);
    int i, j;

    for (i = 0; i < rows; ++i) {
        for (j = 0; j < cols; ++j) {
            matnew->data[i][j] = mat->data[row + i][col + j];
        }
    }
    return matnew;
}

/*
 * Copies all of mat2 into mat1 so that mat2[0][0] lands on mat1[row][col].
 * mat2 must fit inside mat1 at that offset; no bounds checking is performed.
 * Returns mat1.
 */
MATRIXP pastemat(MATRIXP mat1, MATRIXP mat2, int row, int col)
{
    int i, j;

    for (i = 0; i < mat2->rows; ++i) {
        for (j = 0; j < mat2->cols; ++j) {
            mat1->data[row + i][col + j] = mat2->data[i][j];
        }
    }
    return mat1;
}

/*
 * Returns a newly allocated rows x cols matrix filled with integer values
 * drawn with rand() from the inclusive range [lowbound, upbound].
 * If the bounds are given in reverse order they are swapped, which avoids a
 * modulo by zero or by a negative span. The generator is not seeded here.
 */
MATRIXP randmat(int rows, int cols, int lowbound, int upbound)
{
    MATRIXP mat = newmat(rows, cols);
    int span;
    int i, j;

    if (upbound < lowbound) {
        int tmp = lowbound;

        lowbound = upbound;
        upbound = tmp;
    }
    span = upbound - lowbound + 1;

    for (i = 0; i < rows; ++i) {
        for (j = 0; j < cols; ++j) {
            mat->data[i][j] = rand() % span + lowbound;
        }
    }
    return mat;
}

/*
 * In-place addition: mat1 += mat2 over mat1's dimensions.
 * mat2 must be at least as large as mat1. Returns mat1.
 */
MATRIXP addmat(MATRIXP mat1, MATRIXP mat2)
{
    int i, j;

    for (i = 0; i < mat1->rows; ++i) {
        for (j = 0; j < mat1->cols; ++j) {
            mat1->data[i][j] += mat2->data[i][j];
        }
    }
    return mat1;
}

/*
 * In-place subtraction: mat1 -= mat2 over mat1's dimensions.
 * mat2 must be at least as large as mat1. Returns mat1.
 */
MATRIXP submat(MATRIXP mat1, MATRIXP mat2)
{
    int i, j;

    for (i = 0; i < mat1->rows; ++i) {
        for (j = 0; j < mat1->cols; ++j) {
            mat1->data[i][j] -= mat2->data[i][j];
        }
    }
    return mat1;
}

/*
 * Schoolbook multiplication: overwrites mat with mat1 * mat2.
 * mat supplies the result dimensions and mat1->cols the inner dimension;
 * the caller must pass conformable matrices. Returns mat.
 */
MATRIXP multmattrivial(MATRIXP mat, MATRIXP mat1, MATRIXP mat2)
{
    int i, j, k;

    for (i = 0; i < mat->rows; ++i) {
        for (j = 0; j < mat->cols; ++j) {
            mat->data[i][j] = 0;
            for (k = 0; k < mat1->cols; ++k) {
                mat->data[i][j] += mat1->data[i][k] * mat2->data[k][j];
            }
        }
    }
    return mat;
}

/*
 * Strassen multiplication of two square matrices of the same size.
 * The size is taken from mat1->rows; both operands must be size x size.
 * Returns a newly allocated size x size product owned by the caller.
 * The operands are not modified. A size of zero or less yields an empty
 * matrix. All temporaries created during the recursion are released.
 */
MATRIXP multmatstrassen(MATRIXP mat1, MATRIXP mat2)
{
    int size = mat1->rows;
    int halfsize = size / 2;
    int i, j;
    MATRIXP a11, a12, a21, a22, b11, b12, b21, b22;
    MATRIXP p1, p2, p3, p4, p5, p6, p7;
    MATRIXP lhs, rhs;
    MATRIXP mat = newmat(size, size);

    /* Nothing to multiply; without this guard the even-size branch below
       would recurse on size 0 forever. */
    if (size <= 0) {
        return mat;
    }

    if (size == 1) {
        mat->data[0][0] = mat1->data[0][0] * mat2->data[0][0];
        return mat;
    }

    if (size % 2 == 1) {
        /* Odd size: multiply the leading (size-1) x (size-1) blocks
           recursively, then account for the last row and column directly. */
        int last = size - 1;
        MATRIXP sub1 = partmat(mat1, 0, 0, last, last);
        MATRIXP sub2 = partmat(mat2, 0, 0, last, last);
        MATRIXP subprod = multmatstrassen(sub1, sub2);

        pastemat(mat, subprod, 0, 0);
        freemat(sub1);
        freemat(sub2);
        freemat(subprod);

        /* The leading block still lacks the k = last term of each dot product. */
        for (i = 0; i < last; ++i) {
            for (j = 0; j < last; ++j) {
                mat->data[i][j] += mat1->data[i][last] * mat2->data[last][j];
            }
        }
        /* Last column, including the bottom-right corner. */
        for (i = 0; i < size; ++i) {
            mat->data[i][last] = 0;
            for (j = 0; j < size; ++j) {
                mat->data[i][last] += mat1->data[i][j] * mat2->data[j][last];
            }
        }
        /* Last row, excluding the corner already computed above. */
        for (i = 0; i < last; ++i) {
            mat->data[last][i] = 0;
            for (j = 0; j < size; ++j) {
                mat->data[last][i] += mat1->data[last][j] * mat2->data[j][i];
            }
        }
        return mat;
    }

    /* Even size: split both operands into four halfsize x halfsize quadrants. */
    a11 = partmat(mat1, 0, 0, halfsize, halfsize);
    a12 = partmat(mat1, 0, halfsize, halfsize, halfsize);
    a21 = partmat(mat1, halfsize, 0, halfsize, halfsize);
    a22 = partmat(mat1, halfsize, halfsize, halfsize, halfsize);
    b11 = partmat(mat2, 0, 0, halfsize, halfsize);
    b12 = partmat(mat2, 0, halfsize, halfsize, halfsize);
    b21 = partmat(mat2, halfsize, 0, halfsize, halfsize);
    b22 = partmat(mat2, halfsize, halfsize, halfsize, halfsize);

    /* The seven Strassen products. addmat/submat work in place, so each
       combined operand is built on a copy that is freed right after use. */

    /* p1 = (a11 + a22) * (b11 + b22) */
    lhs = addmat(copymat(a11), a22);
    rhs = addmat(copymat(b11), b22);
    p1 = multmatstrassen(lhs, rhs);
    freemat(lhs);
    freemat(rhs);

    /* p2 = (a21 + a22) * b11 */
    lhs = addmat(copymat(a21), a22);
    p2 = multmatstrassen(lhs, b11);
    freemat(lhs);

    /* p3 = a11 * (b12 - b22) */
    rhs = submat(copymat(b12), b22);
    p3 = multmatstrassen(a11, rhs);
    freemat(rhs);

    /* p4 = a22 * (b21 - b11) */
    rhs = submat(copymat(b21), b11);
    p4 = multmatstrassen(a22, rhs);
    freemat(rhs);

    /* p5 = (a11 + a12) * b22 */
    lhs = addmat(copymat(a11), a12);
    p5 = multmatstrassen(lhs, b22);
    freemat(lhs);

    /* p6 = (a21 - a11) * (b11 + b12) */
    lhs = submat(copymat(a21), a11);
    rhs = addmat(copymat(b11), b12);
    p6 = multmatstrassen(lhs, rhs);
    freemat(lhs);
    freemat(rhs);

    /* p7 = (a12 - a22) * (b21 + b22) */
    lhs = submat(copymat(a12), a22);
    rhs = addmat(copymat(b21), b22);
    p7 = multmatstrassen(lhs, rhs);
    freemat(lhs);
    freemat(rhs);

    /* Assemble the four result quadrants directly in the output:
         c11 = p1 + p4 - p5 + p7      c12 = p3 + p5
         c21 = p2 + p4                c22 = p1 + p3 - p2 + p6 */
    for (i = 0; i < halfsize; ++i) {
        for (j = 0; j < halfsize; ++j) {
            mat->data[i][j] =
                p1->data[i][j] + p4->data[i][j] - p5->data[i][j] + p7->data[i][j];
            mat->data[i][halfsize + j] =
                p3->data[i][j] + p5->data[i][j];
            mat->data[halfsize + i][j] =
                p2->data[i][j] + p4->data[i][j];
            mat->data[halfsize + i][halfsize + j] =
                p1->data[i][j] + p3->data[i][j] - p2->data[i][j] + p6->data[i][j];
        }
    }

    freemat(a11);
    freemat(a12);
    freemat(a21);
    freemat(a22);
    freemat(b11);
    freemat(b12);
    freemat(b21);
    freemat(b22);
    freemat(p1);
    freemat(p2);
    freemat(p3);
    freemat(p4);
    freemat(p5);
    freemat(p6);
    freemat(p7);

    return mat;
}

/* Prints mat to stdout, one row per line, each element formatted "% 6.2f ". */
void printmat(MATRIXP mat)
{
    int i, j;

    for (i = 0; i < mat->rows; ++i) {
        for (j = 0; j < mat->cols; ++j) {
            printf("% 6.2f ", mat->data[i][j]);
        }
        printf("\n");
    }
}

#ifndef MATRIX_NO_MAIN
/* Demo: multiplies two random 5x5 matrices with both algorithms and prints
   the operands and both results so they can be compared. */
int main(void)
{
    MATRIXP mat1 = randmat(5, 5, -2, 2);
    MATRIXP mat2 = randmat(5, 5, -2, 2);
    MATRIXP mat3 = newmat(5, 5);
    MATRIXP mat4;

    printf("A:\n");
    printmat(mat1);
    printf("\nB:\n");
    printmat(mat2);

    multmattrivial(mat3, mat1, mat2);
    mat4 = multmatstrassen(mat1, mat2);

    printf("\nC trivial:\n");
    printmat(mat3);
    printf("\nC strassen:\n");
    printmat(mat4);

    freemat(mat1);
    freemat(mat2);
    freemat(mat3);
    freemat(mat4);
    return 0;
}
#endif /* MATRIX_NO_MAIN */