Skip to content

Commit 5d123ef

Browse files
author
hiSandog
committed
docs: Add detailed docstrings explaining Strassen algorithm complexity
- Added comprehensive docstring to strassen() function explaining: - Time complexity: O(n^2.807) vs O(n³) for standard multiplication - Space complexity: O(n²) - The 7 matrix multiplication formulas used - When the algorithm is beneficial (large matrices) - Input requirements (dimensions must be power of 2) - Improved docstring for actual_strassen() with args and raises documentation - Addresses issue #14084: adding better docstring explaining Strassen's complexity
1 parent f527d43 commit 5d123ef

File tree

1 file changed

+58
-2
lines changed

1 file changed

+58
-2
lines changed

divide_and_conquer/strassen_matrix_multiplication.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,28 @@ def print_matrix(matrix: list) -> None:
7373

7474
def actual_strassen(matrix_a: list, matrix_b: list) -> list:
7575
"""
76-
Recursive function to calculate the product of two matrices, using the Strassen
77-
Algorithm. It only supports square matrices of any size that is a power of 2.
76+
Recursive function to calculate the product of two matrices using the Strassen
77+
Algorithm.
78+
79+
This is the core recursive implementation that only supports square matrices
80+
of size that is a power of 2 (e.g., 2x2, 4x4, 8x8, 16x16, etc.).
81+
82+
The algorithm works by:
83+
1. Base case: For 2x2 matrices, use standard multiplication
84+
2. Recursive case:
85+
- Split both matrices into 4 quadrants
86+
- Compute 7 products using the formulas above
87+
- Combine the 7 products to get the final result
88+
89+
Args:
90+
matrix_a: Square matrix with dimensions as power of 2
91+
matrix_b: Square matrix with dimensions as power of 2
92+
93+
Returns:
94+
Product matrix
95+
96+
Raises:
97+
Exception: If matrices are not square or dimensions are not power of 2
7898
"""
7999
if matrix_dimensions(matrix_a) == (2, 2):
80100
return default_matrix_multiplication(matrix_a, matrix_b)
@@ -106,6 +126,42 @@ def actual_strassen(matrix_a: list, matrix_b: list) -> list:
106126

107127
def strassen(matrix1: list, matrix2: list) -> list:
108128
"""
129+
Multiplies two matrices using the Strassen algorithm for improved time complexity.
130+
131+
The Strassen algorithm reduces the complexity of matrix multiplication from
132+
O(n³) to O(n^2.807) by reducing the number of recursive matrix multiplications
133+
from 8 to 7. While the asymptotic complexity is better, the actual performance
134+
improvement is typically seen only for large matrices due to higher constant
135+
factors and additional overhead.
136+
137+
Time Complexity: O(n^2.807) - Strassen vs O(n³) for standard multiplication
138+
Space Complexity: O(n²) for storing the result matrix
139+
140+
The algorithm works by recursively dividing matrices into 2x2 submatrices and
141+
computing 7 products (P1-P7) instead of 8:
142+
P1 = A * (F - H)
143+
P2 = (A + B) * H
144+
P3 = (C + D) * E
145+
P4 = D * (G - E)
146+
P5 = (A + D) * (E + H)
147+
P6 = (B - D) * (G + H)
148+
P7 = (A - C) * (E + F)
149+
150+
Then combines these products to get the final result matrix.
151+
152+
Note: This implementation requires input matrices to have dimensions that are
153+
powers of 2. The function automatically pads matrices with zeros if needed.
154+
155+
Args:
156+
matrix1: First matrix (m x n)
157+
matrix2: Second matrix (n x p)
158+
159+
Returns:
160+
Result matrix (m x p)
161+
162+
Raises:
163+
Exception: If matrix dimensions are incompatible for multiplication
164+
109165
>>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]])
110166
[[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]]
111167
>>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])

0 commit comments

Comments
 (0)