Exercise 4.2.2
Write pseudocode for Strassen's algorithm
Pseudocode is for sissies. Let's write it in C!
60 minutes later: OK, bad decision.
C code
#include <alloca.h>
#include <stdio.h>
#include <stdlib.h>

// The matrix representation. One structure to fit both a matrix and a
// submatrix.
//
// `data` points at the backing storage of the *original* square matrix, laid
// out row-major with `original_size` elements per row. (`x`, `y`) is the
// row/column offset of this view inside that storage and `size` is the side
// length of the view. A view never owns `data`.
typedef struct {
    int x;
    int y;
    int size;
    int original_size;
    int *data;
} matrix;

// Functions to index matrices

// Returns the element at row `x`, column `y` of the view `m`.
// No bounds checking is performed.
int get(matrix m, int x, int y) {
    return m.data[m.original_size * (m.x + x) + m.y + y];
}

// Stores `value` at row `x`, column `y` of the view `m`.
// `m` is passed by value, but the write goes through the shared `data`
// pointer, so it is visible to every view over the same storage.
void put(matrix m, int x, int y, int value) {
    m.data[m.original_size * (m.x + x) + m.y + y] = value;
}

// Matrix building

// Wraps a caller-provided buffer of `size * size` ints as a full matrix.
matrix create_matrix(int size, int *data) {
    matrix result;

    result.x = 0;
    result.y = 0;
    result.size = size;
    result.original_size = size;
    result.data = data;

    return result;
}

// Returns a `size` x `size` view into `A` whose top-left corner is at row `x`,
// column `y` of `A`. The view shares storage with `A`; offsets compose, so a
// submatrix of a submatrix works as expected.
matrix submatrix(matrix A, int x, int y, int size) {
    matrix result;

    result.x = A.x + x;
    result.y = A.y + y;
    result.size = size;
    result.original_size = A.original_size;
    result.data = A.data;

    return result;
}

// Initializes `m_` as a `size_` x `size_` matrix backed by alloca() storage.
//
// The storage lives until the *calling function* returns, so this must be
// expanded directly in the function that uses the matrix. `size_` is evaluated
// more than once; pass a plain variable or constant. Stack space is limited:
// only use this for small matrices.
//
// Wrapped in do { } while (0) so it behaves as a single statement; use it with
// a trailing semicolon.
#define INIT_ON_STACK(m_, size_)                                              \
    do {                                                                      \
        (m_).x = 0;                                                           \
        (m_).y = 0;                                                           \
        (m_).size = (size_);                                                  \
        (m_).original_size = (size_);                                         \
        (m_).data =                                                           \
            alloca((size_t)(size_) * (size_t)(size_) * sizeof(int));          \
    } while (0)

// Adding and subtracting matrices
//
// All of these iterate over the side length of their first argument and
// assume the other operands are at least that large.

// C = A + B
void plus(matrix C, matrix A, matrix B) {
    for (int i = 0; i < C.size; i++) {
        for (int j = 0; j < C.size; j++) {
            put(C, i, j, get(A, i, j) + get(B, i, j));
        }
    }
}

// C = A - B
void minus(matrix C, matrix A, matrix B) {
    for (int i = 0; i < C.size; i++) {
        for (int j = 0; j < C.size; j++) {
            put(C, i, j, get(A, i, j) - get(B, i, j));
        }
    }
}

// T += S
void add(matrix T, matrix S) {
    for (int i = 0; i < T.size; i++) {
        for (int j = 0; j < T.size; j++) {
            put(T, i, j, get(T, i, j) + get(S, i, j));
        }
    }
}

// T -= S
void sub(matrix T, matrix S) {
    for (int i = 0; i < T.size; i++) {
        for (int j = 0; j < T.size; j++) {
            put(T, i, j, get(T, i, j) - get(S, i, j));
        }
    }
}

// Sets every element of the view `m` to 0.
void zero(matrix m) {
    for (int i = 0; i < m.size; i++) {
        for (int j = 0; j < m.size; j++) {
            put(m, i, j, 0);
        }
    }
}

// A function to print matrices

// Prints a header line with the view's dimensions, offsets and the row stride
// of the backing storage, followed by the elements, one row per line.
void print_matrix(matrix m) {
    printf("%dx%d (+%d+%d) (%d)\n", m.size, m.size, m.x, m.y, m.original_size);
    printf("==============\n");

    for (int i = 0; i < m.size; i++) {
        for (int j = 0; j < m.size; j++) {
            printf("%4d", get(m, i, j));
        }
        printf("\n");
    }

    printf("\n");
}

// Strassen's algorithm

void strassen(matrix C, matrix A, matrix B);

// Allocates `count` zeroed ints on the heap. Out-of-memory policy: print a
// message and abort(), because strassen() has no way to report failure.
static int *alloc_cells(size_t count) {
    int *cells = calloc(count, sizeof *cells);

    if (cells == NULL) {
        fprintf(stderr, "strassen: out of memory\n");
        abort();
    }

    return cells;
}

// Hands out the next `size` x `size` matrix from a contiguous scratch buffer
// and advances `*cursor` past it.
static matrix carve(int **cursor, int size) {
    matrix result = create_matrix(size, *cursor);

    *cursor += (size_t)size * (size_t)size;

    return result;
}

// Multiplies odd-sized matrices by embedding A and B in the top-left corner of
// zero-filled (size + 1) x (size + 1) matrices. The zero padding contributes
// nothing to the product, whose top-left size x size corner is then copied
// into C.
static void strassen_padded(matrix C, matrix A, matrix B) {
    int size = A.size;
    int padded = size + 1;
    size_t cells = (size_t)padded * (size_t)padded;

    int *workspace = alloc_cells(3 * cells);
    int *cursor = workspace;
    matrix pa = carve(&cursor, padded);
    matrix pb = carve(&cursor, padded);
    matrix pc = carve(&cursor, padded);

    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            put(pa, i, j, get(A, i, j));
            put(pb, i, j, get(B, i, j));
        }
    }

    strassen(pc, pa, pb);

    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            put(C, i, j, get(pc, i, j));
        }
    }

    free(workspace);
}

// Computes C = A * B with Strassen's algorithm (seven recursive products per
// level instead of eight).
//
// A, B and C must be square views of the same side length; A.size is the one
// that is used. Any positive size is accepted: odd sizes are zero-padded to
// the next even size at the level where they occur. A size <= 0 is a no-op.
//
// C may alias A or B: every input element is read before any element of C is
// written.
//
// Scratch matrices live on the heap (one allocation per recursion level) so
// large inputs cannot overflow the stack. If that allocation fails the process
// is aborted.
void strassen(matrix C, matrix A, matrix B) {
    int size = A.size;

    if (size <= 0) {
        return;
    }

    if (size == 1) {
        put(C, 0, 0, get(A, 0, 0) * get(B, 0, 0));
        return;
    }

    if (size % 2 != 0) {
        // size / 2 would silently drop the last row and column.
        strassen_padded(C, A, B);
        return;
    }

    int half = size / 2;

    // Ten sums/differences (S1..S10) and seven products (P1..P7).
    enum { SCRATCH_MATRICES = 17 };
    int *workspace =
        alloc_cells(SCRATCH_MATRICES * (size_t)half * (size_t)half);
    int *cursor = workspace;

    matrix s1 = carve(&cursor, half);
    matrix s2 = carve(&cursor, half);
    matrix s3 = carve(&cursor, half);
    matrix s4 = carve(&cursor, half);
    matrix s5 = carve(&cursor, half);
    matrix s6 = carve(&cursor, half);
    matrix s7 = carve(&cursor, half);
    matrix s8 = carve(&cursor, half);
    matrix s9 = carve(&cursor, half);
    matrix s10 = carve(&cursor, half);

    matrix p1 = carve(&cursor, half);
    matrix p2 = carve(&cursor, half);
    matrix p3 = carve(&cursor, half);
    matrix p4 = carve(&cursor, half);
    matrix p5 = carve(&cursor, half);
    matrix p6 = carve(&cursor, half);
    matrix p7 = carve(&cursor, half);

    matrix a11 = submatrix(A, 0, 0, half);
    matrix a12 = submatrix(A, 0, half, half);
    matrix a21 = submatrix(A, half, 0, half);
    matrix a22 = submatrix(A, half, half, half);

    matrix b11 = submatrix(B, 0, 0, half);
    matrix b12 = submatrix(B, 0, half, half);
    matrix b21 = submatrix(B, half, 0, half);
    matrix b22 = submatrix(B, half, half, half);

    matrix c11 = submatrix(C, 0, 0, half);
    matrix c12 = submatrix(C, 0, half, half);
    matrix c21 = submatrix(C, half, 0, half);
    matrix c22 = submatrix(C, half, half, half);

    minus(s1, b12, b22);
    plus(s2, a11, a12);
    plus(s3, a21, a22);
    minus(s4, b21, b11);
    plus(s5, a11, a22);
    plus(s6, b11, b22);
    minus(s7, a12, a22);
    plus(s8, b21, b22);
    minus(s9, a11, a21);
    plus(s10, b11, b12);

    strassen(p1, a11, s1);
    strassen(p2, s2, b22);
    strassen(p3, s3, b11);
    strassen(p4, a22, s4);
    strassen(p5, s5, s6);
    strassen(p6, s7, s8);
    strassen(p7, s9, s10);

    // Only now is C written, which is what makes aliasing with A or B safe.

    // C11 = P5 + P4 - P2 + P6
    plus(c11, p5, p4);
    sub(c11, p2);
    add(c11, p6);

    // C12 = P1 + P2
    plus(c12, p1, p2);

    // C21 = P3 + P4
    plus(c21, p3, p4);

    // C22 = P5 + P1 - P3 - P7
    plus(c22, p5, p1);
    sub(c22, p3);
    sub(c22, p7);

    free(workspace);
}