Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions divide_and_conquer/strassen_matrix_multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,28 @@ def print_matrix(matrix: list) -> None:

def actual_strassen(matrix_a: list, matrix_b: list) -> list:
"""
Recursive function to calculate the product of two matrices, using the Strassen
Algorithm. It only supports square matrices of any size that is a power of 2.
Recursive function to calculate the product of two matrices using the Strassen
Algorithm.

This is the core recursive implementation that only supports square matrices
of size that is a power of 2 (e.g., 2x2, 4x4, 8x8, 16x16, etc.).

The algorithm works by:
1. Base case: For 2x2 matrices, use standard multiplication
2. Recursive case:
- Split both matrices into 4 quadrants
- Compute 7 products using the formulas above
- Combine the 7 products to get the final result

Args:
matrix_a: Square matrix with dimensions as power of 2
matrix_b: Square matrix with dimensions as power of 2

Returns:
Product matrix

Raises:
Exception: If matrices are not square or dimensions are not power of 2
"""
if matrix_dimensions(matrix_a) == (2, 2):
return default_matrix_multiplication(matrix_a, matrix_b)
Expand Down Expand Up @@ -106,6 +126,42 @@ def actual_strassen(matrix_a: list, matrix_b: list) -> list:

def strassen(matrix1: list, matrix2: list) -> list:
"""
Multiplies two matrices using the Strassen algorithm for improved time complexity.

The Strassen algorithm reduces the complexity of matrix multiplication from
O(n³) to O(n^2.807) by reducing the number of recursive matrix multiplications
from 8 to 7. While the asymptotic complexity is better, the actual performance
improvement is typically seen only for large matrices due to higher constant
factors and additional overhead.

Time Complexity: O(n^2.807) - Strassen vs O(n³) for standard multiplication
Space Complexity: O(n²) for storing the result matrix

The algorithm works by recursively dividing matrices into 2x2 submatrices and
computing 7 products (P1-P7) instead of 8:
P1 = A * (F - H)
P2 = (A + B) * H
P3 = (C + D) * E
P4 = D * (G - E)
P5 = (A + D) * (E + H)
P6 = (B - D) * (G + H)
P7 = (A - C) * (E + F)

Then combines these products to get the final result matrix.

Note: This implementation requires input matrices to have dimensions that are
powers of 2. The function automatically pads matrices with zeros if needed.

Args:
matrix1: First matrix (m x n)
matrix2: Second matrix (n x p)

Returns:
Result matrix (m x p)

Raises:
Exception: If matrix dimensions are incompatible for multiplication

>>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]])
[[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]]
>>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])
Expand Down