Simple divide and conquer also leads to O(N^3). Can there be a better way?
In the simple divide and conquer method, the main contributor to the high time complexity is its 8 recursive calls. The idea of Strassen’s method is to reduce the number of recursive calls to 7. Strassen’s method is similar to the simple divide and conquer method in that it also divides the matrices into sub-matrices of size N/2 x N/2, as shown in the diagram. The difference is how it computes the four sub-matrices of the result. It forms seven products of N/2 x N/2 matrices (p1 to p7 in the Java code) and combines them using only additions and subtractions.
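The block below writes out the seven products and how they combine into the result blocks. It uses the sub-matrix names from the Java code later in this section: a11, a12, a21, a22 for A and b11, b12, b21, b22 for B. Note that this labeling differs from the a to h letters of the simple method. Each p_i is one recursive multiplication, so there are 7 of them instead of 8.

```latex
\begin{aligned}
p_1 &= (a_{11} + a_{22})(b_{11} + b_{22}) \\
p_2 &= (a_{21} + a_{22})\, b_{11} \\
p_3 &= a_{11}\,(b_{12} - b_{22}) \\
p_4 &= a_{22}\,(b_{21} - b_{11}) \\
p_5 &= (a_{11} + a_{12})\, b_{22} \\
p_6 &= (a_{21} - a_{11})(b_{11} + b_{12}) \\
p_7 &= (a_{12} - a_{22})(b_{21} + b_{22}) \\[4pt]
c_{11} &= p_1 + p_4 - p_5 + p_7 \\
c_{12} &= p_3 + p_5 \\
c_{21} &= p_2 + p_4 \\
c_{22} &= p_1 - p_2 + p_3 + p_6
\end{aligned}
```

As a spot check, c12 = p3 + p5 = a11 b12 - a11 b22 + a11 b22 + a12 b22 = a11 b12 + a12 b22, which is the usual top-right block. The multiplication order is kept throughout because matrix products do not commute.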
Time Complexity of Strassen’s Method
Addition and subtraction of two matrices take O(N^2) time, so the time complexity can be written as
T(N) = 7T(N/2) + O(N^2)
By the Master Theorem, the time complexity of the above method is O(N^(log2 7)), which is approximately O(N^2.8074).
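Here is the Master Theorem step for this recurrence. There are a = 7 subproblems, each of size N/b with b = 2, and f(N) = N^2 covers the additions and subtractions. The additive term grows polynomially slower than N^(log_b a), so case 1 applies:

```latex
\begin{aligned}
T(N) &= 7\,T(N/2) + \Theta(N^2), \qquad a = 7,\; b = 2,\; f(N) = \Theta(N^2) \\
\log_b a &= \log_2 7 \approx 2.8074 > 2 \\
f(N) &= O\!\left(N^{\log_2 7 - \varepsilon}\right) \quad \text{for, e.g., } \varepsilon = 0.8 \\
\Rightarrow\; T(N) &= \Theta\!\left(N^{\log_2 7}\right) \approx \Theta\!\left(N^{2.8074}\right)
\end{aligned}
```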
Generally, Strassen’s method is not preferred for practical applications, for the following reasons:
1) The constant factors in Strassen’s method are high, so for typical matrix sizes the naive method works better.
2) For sparse matrices, there are better methods designed especially for them.
3) The sub-matrices created during the recursion take extra space.
4) Because of the limited precision of computer arithmetic on noninteger values, larger errors accumulate in Strassen’s algorithm than in the naive method (Source: CLRS Book).
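To put a rough number on reason 1, the sketch below counts only scalar additions and multiplications. This is a simplified model that ignores memory traffic, allocation and cache effects. The class `OpCount` and its `cutoff` parameter are illustrative assumptions. `cutoff` plays the same role as `LEAF_SIZE` in the Java code later on.

The per-level costs are:
- One Strassen level costs 7 recursive products plus 18 additions or subtractions of N/2 x N/2 matrices (10 to form the operands, 8 to combine the results).
- The naive method needs N^3 multiplications and N^2(N - 1) additions.

```java
public class OpCount {

    // Scalar operations for the naive n x n product:
    // n^3 multiplications plus n^2 * (n - 1) additions.
    static long naiveOps(long n) {
        return 2 * n * n * n - n * n;
    }

    // Scalar operations for Strassen's method on an n x n matrix
    // (n a power of two) that switches to the naive method once n <= cutoff.
    // Each level: 7 recursive products + 18 block additions/subtractions.
    static long strassenOps(long n, long cutoff) {
        if (n <= cutoff) {
            return naiveOps(n);
        }
        long half = n / 2;
        return 7 * strassenOps(half, cutoff) + 18 * half * half;
    }

    public static void main(String[] args) {
        for (long n = 2; n <= 4096; n *= 2) {
            System.out.printf("n=%5d  naive=%,d  strassen(cutoff 1)=%,d  strassen(cutoff 64)=%,d%n",
                    n, naiveOps(n), strassenOps(n, 1), strassenOps(n, 64));
        }
    }
}
```

In this model, recursing all the way down to 1 x 1 (cutoff 1) only falls below the naive count once n reaches 1024. Stopping at a cutoff of 64 already wins at n = 128. This is one reason implementations, including the Java code in this article, typically fall back to the naive method for small blocks.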
Divide and Conquer
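1) Divide matrices A and B into 4 sub-matrices of size N/2 x N/2 each, as shown in the diagram:
   - A is split into a, b (top row of blocks) and c, d (bottom row).
   - B is split into e, f (top row) and g, h (bottom row).
2) Calculate the following values recursively: ae + bg, af + bh, ce + dg and cf + dh. These are the top-left, top-right, bottom-left and bottom-right blocks of the result.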
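The following is a minimal Java sketch of these two steps. It assumes square matrices whose size N is a power of two. The class and helper names `SimpleDivideConquer`, `block` and `add` are illustrative, not from the original. The sketch makes exactly 8 recursive calls per level and does 4 block additions.

```java
import java.util.Arrays;

public class SimpleDivideConquer {

    // Multiplies two n x n matrices (n assumed to be a power of two) with
    // 8 recursive multiplications of n/2 x n/2 blocks per level.
    static int[][] multiply(int[][] A, int[][] B) {
        int n = A.length;
        if (n == 1) {
            return new int[][] {{A[0][0] * B[0][0]}};
        }
        int half = n / 2;
        // A = [a b; c d] and B = [e f; g h], split as in step 1
        int[][] a = block(A, 0, 0, half), b = block(A, 0, half, half);
        int[][] c = block(A, half, 0, half), d = block(A, half, half, half);
        int[][] e = block(B, 0, 0, half), f = block(B, 0, half, half);
        int[][] g = block(B, half, 0, half), h = block(B, half, half, half);
        // Step 2: each result block is one addition of two recursive products
        int[][] c11 = add(multiply(a, e), multiply(b, g));
        int[][] c12 = add(multiply(a, f), multiply(b, h));
        int[][] c21 = add(multiply(c, e), multiply(d, g));
        int[][] c22 = add(multiply(c, f), multiply(d, h));
        // Copy the four blocks back into one n x n result
        int[][] C = new int[n][n];
        for (int i = 0; i < half; i++) {
            for (int j = 0; j < half; j++) {
                C[i][j] = c11[i][j];
                C[i][j + half] = c12[i][j];
                C[i + half][j] = c21[i][j];
                C[i + half][j + half] = c22[i][j];
            }
        }
        return C;
    }

    // Copies the size x size block of M whose top-left corner is (row, col)
    static int[][] block(int[][] M, int row, int col, int size) {
        int[][] out = new int[size][size];
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                out[i][j] = M[row + i][col + j];
            }
        }
        return out;
    }

    // Element-wise sum of two equally sized square matrices
    static int[][] add(int[][] X, int[][] Y) {
        int n = X.length;
        int[][] out = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                out[i][j] = X[i][j] + Y[i][j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[][] A = {{1, 2}, {3, 4}};
        int[][] B = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(multiply(A, B))); // prints [[19, 22], [43, 50]]
    }
}
```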
This method performs 8 multiplications of N/2 x N/2 matrices and 4 matrix additions. Adding two matrices takes O(N^2) time, so the time complexity can be written as
T(N) = 8T(N/2) + O(N^2)
By the Master Theorem, the time complexity of this method is O(N^3).
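The Master Theorem result can be cross-checked by unrolling the recurrence. Write the O(N^2) term as cN^2, let the base case cost c, and assume N is a power of two. Level i of the recursion then has 8^i subproblems of size N/2^i:

```latex
\begin{aligned}
T(N) &= \sum_{i=0}^{\log_2 N} 8^i \cdot c\left(\frac{N}{2^i}\right)^2
      = cN^2 \sum_{i=0}^{\log_2 N} 2^i \\
     &= cN^2\,(2N - 1) = \Theta(N^3)
\end{aligned}
```

The same unrolling for Strassen's 7T(N/2) gives a geometric sum with ratio 7/4 instead of 2. That sum is dominated by its 7^(log2 N) = N^(log2 7) leaves.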
Naive Method
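#define N 4 /* N is assumed to be a compile-time constant; 4 is only an example */

/* C = A * B for N x N matrices */
void multiply(int A[][N], int B[][N], int C[][N])
{
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            C[i][j] = 0;
            for (int k = 0; k < N; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}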
Time complexity of the above method is O(N^3).
import java.util.ArrayList;

// Enclosing class assumed: the original fragment ends with the closing brace
// of a class whose header was not shown, so the name Strassen is illustrative.
public class Strassen {

    // Sub-problems of this size or smaller are multiplied with ikjAlgorithm.
    // The value 64 is an assumed cutoff, not taken from the original.
    private static final int LEAF_SIZE = 64;

    public static int[][] ikjAlgorithm(int[][] A, int[][] B) {
        int n = A.length;
        // new int arrays are zero-initialised in Java, so C starts at 0
        int[][] C = new int[n][n];
        // i-k-j order scans rows of B and C sequentially, which is
        // usually more cache-friendly than the textbook i-j-k order
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < n; k++) {
                for (int j = 0; j < n; j++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        return C;
    }

    // Smallest power of two >= n. Integer doubling avoids the rounding
    // error that Math.log(n) / Math.log(2) can introduce.
    private static int nextPowerOfTwo(int n) {
        int m = 1;
        while (m < n) {
            m <<= 1;
        }
        return m;
    }

    // Element-wise sum A + B of two equally sized square matrices
    private static int[][] add(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
        return C;
    }

    // Element-wise difference A - B of two equally sized square matrices
    private static int[][] subtract(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
        return C;
    }

    public static int[][] strassen(ArrayList<ArrayList<Integer>> A,
                                   ArrayList<ArrayList<Integer>> B) {
        // Make the matrices bigger (zero-padded up to the next power of two)
        // so that you can apply the Strassen algorithm recursively without
        // having to deal with odd matrix sizes
        int n = A.size();
        int m = nextPowerOfTwo(n);
        int[][] APrep = new int[m][m];
        int[][] BPrep = new int[m][m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                APrep[i][j] = A.get(i).get(j);
                BPrep[i][j] = B.get(i).get(j);
            }
        }
        int[][] CPrep = strassenR(APrep, BPrep);
        // Copy the top-left n x n block back out, dropping the padding
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = CPrep[i][j];
            }
        }
        return C;
    }

    private static int[][] strassenR(int[][] A, int[][] B) {
        int n = A.length;
        if (n <= LEAF_SIZE) {
            return ikjAlgorithm(A, B);
        } else {
            // initializing the new sub-matrices
            int newSize = n / 2;
            int[][] a11 = new int[newSize][newSize];
            int[][] a12 = new int[newSize][newSize];
            int[][] a21 = new int[newSize][newSize];
            int[][] a22 = new int[newSize][newSize];
            int[][] b11 = new int[newSize][newSize];
            int[][] b12 = new int[newSize][newSize];
            int[][] b21 = new int[newSize][newSize];
            int[][] b22 = new int[newSize][newSize];
            // temporaries that hold operand sums and differences
            int[][] aResult;
            int[][] bResult;
            // dividing the matrices into 4 sub-matrices:
            for (int i = 0; i < newSize; i++) {
                for (int j = 0; j < newSize; j++) {
                    a11[i][j] = A[i][j]; // top left
                    a12[i][j] = A[i][j + newSize]; // top right
                    a21[i][j] = A[i + newSize][j]; // bottom left
                    a22[i][j] = A[i + newSize][j + newSize]; // bottom right
                    b11[i][j] = B[i][j]; // top left
                    b12[i][j] = B[i][j + newSize]; // top right
                    b21[i][j] = B[i + newSize][j]; // bottom left
                    b22[i][j] = B[i + newSize][j + newSize]; // bottom right
                }
            }
            // Calculating p1 to p7:
            aResult = add(a11, a22);
            bResult = add(b11, b22);
            int[][] p1 = strassenR(aResult, bResult); // p1 = (a11+a22) * (b11+b22)
            aResult = add(a21, a22); // a21 + a22
            int[][] p2 = strassenR(aResult, b11); // p2 = (a21+a22) * (b11)
            bResult = subtract(b12, b22); // b12 - b22
            int[][] p3 = strassenR(a11, bResult); // p3 = (a11) * (b12 - b22)
            bResult = subtract(b21, b11); // b21 - b11
            int[][] p4 = strassenR(a22, bResult); // p4 = (a22) * (b21 - b11)
            aResult = add(a11, a12); // a11 + a12
            int[][] p5 = strassenR(aResult, b22); // p5 = (a11+a12) * (b22)
            aResult = subtract(a21, a11); // a21 - a11
            bResult = add(b11, b12); // b11 + b12
            int[][] p6 = strassenR(aResult, bResult); // p6 = (a21-a11) * (b11+b12)
            aResult = subtract(a12, a22); // a12 - a22
            bResult = add(b21, b22); // b21 + b22
            int[][] p7 = strassenR(aResult, bResult); // p7 = (a12-a22) * (b21+b22)
            // calculating c12, c21, c11 and c22:
            int[][] c12 = add(p3, p5); // c12 = p3 + p5
            int[][] c21 = add(p2, p4); // c21 = p2 + p4
            aResult = add(p1, p4); // p1 + p4
            bResult = add(aResult, p7); // p1 + p4 + p7
            int[][] c11 = subtract(bResult, p5); // c11 = p1 + p4 - p5 + p7
            aResult = add(p1, p3); // p1 + p3
            bResult = add(aResult, p6); // p1 + p3 + p6
            int[][] c22 = subtract(bResult, p2); // c22 = p1 + p3 - p2 + p6
            // Grouping the results obtained in a single matrix:
            int[][] C = new int[n][n];
            for (int i = 0; i < newSize; i++) {
                for (int j = 0; j < newSize; j++) {
                    C[i][j] = c11[i][j];
                    C[i][j + newSize] = c12[i][j];
                    C[i + newSize][j] = c21[i][j];
                    C[i + newSize][j + newSize] = c22[i][j];
                }
            }
            return C;
        }
    }
}
Read full article from Divide and Conquer | Set 5 (Strassen's Matrix Multiplication) | GeeksforGeeks