Strassen matrix multiplication in Python

Question

    def matrix_addition(A, B):
        # Check if matrices have the same size
        if len(A) != len(B) or len(A[0]) != len(B[0]):
            raise ValueError("Matrices must have the same size")
        # Initialize result matrix with zeros
        result = [[0 for col in range(len(A[0]))] for row in range(len(A))]
        # Add matrices element-wise
        for row in range(len(A)):
            for col in range(len(A[0])):
                result[row][col] = A[row][col] + B[row][col]
        return result

    def matrix_subtraction(A, B):
        # Check if matrices have the same size
        if len(A) != len(B) or len(A[0]) != len(B[0]):
            raise ValueError("Matrices must have the same size")
        # Initialize result matrix with zeros
        result = [[0 for col in range(len(A[0]))] for row in range(len(A))]
        # Subtract matrices element-wise
        for row in range(len(A)):
            for col in range(len(A[0])):
                result[row][col] = A[row][col] - B[row][col]
        return result

    def strassen(a, b):
        if len(A) == 1 and len(A[0]) == 1:
            return a[0][0] * b[0][0]
        else:
            # divide into quadrants
            quad1_a, quad2_a, quad3_a, quad4_a = divide(a)
            quad1_b, quad2_b, quad3_b, quad4_b = divide(b)
            #break into parts to compute
            p1 = strassen(matrix_addition(quad1_a, quad4_a), matrix_addition(quad1_b, quad4_b))
            p2 = strassen(matrix_addition(quad3_a + quad4_a), quad1_b)
            p3 = strassen(quad1_a, matrix_subtraction(quad3_b, quad1_b))
            p4 = strassen(quad4_a, matrix_subtraction(quad3_b, quad1_b))
            p5 = strassen(matrix_addition(quad1_a, quad2_a), quad4_b)
            p6 = strassen(matrix_subtraction(quad3_a, quad1_a), matrix_addition(quad1_b, quad2_b))
            p7 = strassen(matrix_subtraction(quad2_a, quad4_a), matrix_addition(quad3_b, quad4_b))
            # create the final matrix
            final_quad1 = matrix_subtraction(matrix_addition(p1, p4), matrix_addition(p5, p7))
            final_quad2 = matrix_addition(p3, p5)
            final_quad3 = matrix_addition(p2, p4)
            final_quad4 = matrix_addition(matrix_subtraction(p1, p2), matrix_addition(p3, p6))
            resultant_matrix = combine_submatrices(final_quad1, final_quad2, final_quad3, final_quad4)
            return resultant_matrix

It's a basic implementation of the Strassen algorithm. I have tested all the helper functions and they work, but when everything is put together I keep running into problems.
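For reference, these are the standard Strassen formulas. Here A_{11}, A_{12}, A_{21}, A_{22} stand for quad1_a, quad2_a, quad3_a, quad4_a (top-left, top-right, bottom-left, bottom-right, which is how the final_quad lines appear to use them; that mapping is an assumption), and likewise for B. Checking each p and final_quad line of the code against them is a quick way to catch sign and operand mistakes.

```latex
\begin{align*}
p_1 &= (A_{11} + A_{22})(B_{11} + B_{22}) \\
p_2 &= (A_{21} + A_{22})\,B_{11} \\
p_3 &= A_{11}\,(B_{12} - B_{22}) \\
p_4 &= A_{22}\,(B_{21} - B_{11}) \\
p_5 &= (A_{11} + A_{12})\,B_{22} \\
p_6 &= (A_{21} - A_{11})(B_{11} + B_{12}) \\
p_7 &= (A_{12} - A_{22})(B_{21} + B_{22}) \\[4pt]
C_{11} &= p_1 + p_4 - p_5 + p_7 \\
C_{12} &= p_3 + p_5 \\
C_{21} &= p_2 + p_4 \\
C_{22} &= p_1 - p_2 + p_3 + p_6
\end{align*}
```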

The strassen function is supposed to take two 2D arrays of size 2^n × 2^n. For the above code I used these arrays:

    A = [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12],
         [13, 14, 15, 16]]
    B = [[17, 18, 19, 20],
         [21, 22, 23, 24],
         [25, 26, 27, 28],
         [29, 30, 31, 32]]

The result should be:

    C = [[250, 260, 270, 280],
         [618, 644, 670, 696],
         [986, 1028, 1070, 1112],
         [1354, 1412, 1470, 1528]]
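A quick way to confirm the expected C, and to have something to compare Strassen's output against, is the plain triple-loop product. A minimal sketch; naive_multiply is a hypothetical helper name, not part of the original code:

```python
def naive_multiply(a, b):
    """Textbook O(n^3) product: c[i][j] = sum over k of a[i][k] * b[k][j]."""
    n, m, p = len(a), len(b), len(b[0])
    if len(a[0]) != m:
        raise ValueError("Inner dimensions must match")
    return [[sum(a[i][k] * b[k][j] for k in range(m)) for j in range(p)]
            for i in range(n)]


A = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
B = [[17, 18, 19, 20],
     [21, 22, 23, 24],
     [25, 26, 27, 28],
     [29, 30, 31, 32]]

for row in naive_multiply(A, B):
    print(row)  # first row printed: [250, 260, 270, 280]
```

The same function is handy as a test oracle: strassen(A, B) == naive_multiply(A, B) should hold for any pair of equally sized power-of-two matrices.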

I have run the code multiple times and I get this error:

    raise ValueError("Matrices must have the same size")
    ValueError: Matrices must have the same size
    Process finished with exit code 1

If I turn the exception check off, I run into a different problem:

    Traceback (most recent call last):
      File "strassen_matrix.py", line 112, in <module>
        print(strassen(A, B))
              ^^^^^^^^^^^^^^
      File "strassen_matrix.py", line 97, in strassen
        p1 = strassen(matrix_addition(quad1_a, quad4_a), matrix_addition(quad1_b, quad4_b))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "strassen_matrix.py", line 97, in strassen
        p1 = strassen(matrix_addition(quad1_a, quad4_a), matrix_addition(quad1_b, quad4_b))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "strassen_matrix.py", line 97, in strassen
        p1 = strassen(matrix_addition(quad1_a, quad4_a), matrix_addition(quad1_b, quad4_b))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      [Previous line repeated 994 more times]
      File "strassen_matrix.py", line 95, in strassen
        quad1_a, quad2_a, quad3_a, quad4_a = divide(a)
                                             ^^^^^^^^^
      File "strassen_matrix.py", line 20, in divide
        for x in range(int(len(matrix) / 2)):
                       ^^^^^^^^^^^^^^^^^^^^
    RecursionError: maximum recursion depth exceeded while calling a Python object

I tried increasing the recursion limit as well, but I still get the same issue. I'm stumped as to how to fix this; any help is appreciated.
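One way to see why raising the limit cannot help: the recursion here never terminates at all. A correct Strassen recursion halves the matrix at each level, so the 4×4 example only nests three strassen calls deep (sizes 4, 2, 1), far below CPython's default recursion limit of 1000. The sketch below isolates the base-case check from the question; the function names are hypothetical. Because the check reads the module-level A rather than the parameter a, it can never become true.

```python
# Hypothetical minimal reproduction of the base-case bug: the check reads the
# module-level A instead of the parameter a, so it never becomes true.
A = [[1, 2], [3, 4]]  # global 2x2 matrix, standing in for the example A


def is_base_case_buggy(a):
    # Looks at the global A, whose size never changes during recursion
    return len(A) == 1 and len(A[0]) == 1


def is_base_case_fixed(a):
    # Looks at the argument, which shrinks by half on every recursive call
    return len(a) == 1 and len(a[0]) == 1


print(is_base_case_buggy([[5]]))  # False: recursion would keep going
print(is_base_case_fixed([[5]]))  # True: recursion stops at 1x1
```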

Answer 1

Score: 2

  • if len(A) == 1 and len(A[0]) == 1 引用了全局名称 AB,而不是参数名称 ab(下一条语句也是如此)。通过将主要逻辑也放入函数中,使其 AB 也成为局部变量,避免名称混淆。

  • return A[0][0] * B[0][0] 返回一个数字,而调用者期望得到一个矩阵。因此(连同上面的修复),应该是

    1. return [[a[0][0] * b[0][0]]]
  • p2 的公式中,你有一个拼写错误:matrix_addition(quad3_a + quad4_a) 应该是:

    1. matrix_addition(quad3_a, quad4_a)
  • p3 = strassen(quad1_a, matrix_subtraction(quad3_b, quad1_b)) 用相反的符号计算了 p3。你会想要

    1. p3 = strassen(quad1_a, matrix_subtraction(quad3_b, quad1_b))
  • final_quad1 = matrix_subtraction(matrix_addition(p1, p4), matrix_addition(p5, p7)) 执行了两个项的减法,而不是一个。修正为

    1. final_quad1 = matrix_addition(matrix_subtraction(matrix_addition(p1, p4), p5), p7)

修复这些问题后,它将正常工作。

英文:

You have a few errors:

  • if len(A) == 1 and len(A[0]) == 1 references the global names A and B instead of the parameter names a and b (and the same goes for the next statement). You can avoid this kind of name confusion by putting the main code in a function as well, so that its A and B are local too.

  • return A[0][0] * B[0][0] returns a number, while the caller expects a matrix. So, together with the fix above, it should be:

    return [[a[0][0] * b[0][0]]]
  • In the formula for p2 you have a typo: matrix_addition(quad3_a + quad4_a) should be:

    matrix_addition(quad3_a, quad4_a)
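  • p3 = strassen(quad1_a, matrix_subtraction(quad3_b, quad1_b)) does not compute Strassen's p3: it subtracts quad1_b from quad3_b, which are the operands of p4, instead of subtracting quad4_b from quad2_b. You'll want:

    p3 = strassen(quad1_a, matrix_subtraction(quad2_b, quad4_b))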
  • final_quad1 = matrix_subtraction(matrix_addition(p1, p4), matrix_addition(p5, p7)) subtracts both p5 and p7, but only p5 should be subtracted (the top-left quadrant is p1 + p4 - p5 + p7). Correct it to:

    final_quad1 = matrix_addition(matrix_subtraction(matrix_addition(p1, p4), p5), p7)
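Because these fixes change signs and operands, it can help to check the formulas on the smallest case. A minimal sketch, using the top-left 2×2 block of the example matrices (an arbitrary choice) so that every quadrant is a plain number and no helper functions are needed:

```python
# With 2x2 inputs each "quadrant" is a single number, so the seven Strassen
# products and the four combinations can be checked with plain arithmetic.
a11, a12, a21, a22 = 1, 2, 5, 6      # top-left 2x2 block of the example A
b11, b12, b21, b22 = 17, 18, 21, 22  # top-left 2x2 block of the example B

p1 = (a11 + a22) * (b11 + b22)
p2 = (a21 + a22) * b11
p3 = a11 * (b12 - b22)               # corrected p3: quad2_b - quad4_b
p4 = a22 * (b21 - b11)
p5 = (a11 + a12) * b22
p6 = (a21 - a11) * (b11 + b12)
p7 = (a12 - a22) * (b21 + b22)

strassen_c = [[p1 + p4 - p5 + p7, p3 + p5],   # corrected final_quad1
              [p2 + p4, p1 - p2 + p3 + p6]]
direct_c = [[a11 * b11 + a12 * b21, a11 * b12 + a12 * b22],
            [a21 * b11 + a22 * b21, a21 * b12 + a22 * b22]]

print(strassen_c == direct_c)                 # True
print(p1 + p4 - (p5 + p7) == direct_c[0][0])  # original final_quad1: False
```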

After fixing those issues, it will work.
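Putting all of the fixes together gives roughly the following sketch. The question does not show divide or combine_submatrices, so the slicing-based versions here are hypothetical stand-ins; matrix_addition and matrix_subtraction are condensed with zip but check sizes the same way. Like the original, it assumes square matrices whose size is a power of 2, and it keeps the example matrices inside a main() function so the base case cannot accidentally read them.

```python
def matrix_addition(A, B):
    # Element-wise sum of two equally sized matrices
    if len(A) != len(B) or len(A[0]) != len(B[0]):
        raise ValueError("Matrices must have the same size")
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(A, B)]


def matrix_subtraction(A, B):
    # Element-wise difference of two equally sized matrices
    if len(A) != len(B) or len(A[0]) != len(B[0]):
        raise ValueError("Matrices must have the same size")
    return [[x - y for x, y in zip(ra, rb)] for ra, rb in zip(A, B)]


def divide(matrix):
    # Hypothetical helper: split a square matrix into
    # top-left, top-right, bottom-left, bottom-right quadrants
    mid = len(matrix) // 2
    return ([row[:mid] for row in matrix[:mid]],
            [row[mid:] for row in matrix[:mid]],
            [row[:mid] for row in matrix[mid:]],
            [row[mid:] for row in matrix[mid:]])


def combine_submatrices(quad1, quad2, quad3, quad4):
    # Hypothetical helper: glue four quadrants back into one matrix
    top = [left + right for left, right in zip(quad1, quad2)]
    bottom = [left + right for left, right in zip(quad3, quad4)]
    return top + bottom


def strassen(a, b):
    # Base case tests the parameters and returns a 1x1 matrix, not a number
    if len(a) == 1 and len(a[0]) == 1:
        return [[a[0][0] * b[0][0]]]
    quad1_a, quad2_a, quad3_a, quad4_a = divide(a)
    quad1_b, quad2_b, quad3_b, quad4_b = divide(b)
    p1 = strassen(matrix_addition(quad1_a, quad4_a), matrix_addition(quad1_b, quad4_b))
    p2 = strassen(matrix_addition(quad3_a, quad4_a), quad1_b)
    p3 = strassen(quad1_a, matrix_subtraction(quad2_b, quad4_b))
    p4 = strassen(quad4_a, matrix_subtraction(quad3_b, quad1_b))
    p5 = strassen(matrix_addition(quad1_a, quad2_a), quad4_b)
    p6 = strassen(matrix_subtraction(quad3_a, quad1_a), matrix_addition(quad1_b, quad2_b))
    p7 = strassen(matrix_subtraction(quad2_a, quad4_a), matrix_addition(quad3_b, quad4_b))
    final_quad1 = matrix_addition(matrix_subtraction(matrix_addition(p1, p4), p5), p7)
    final_quad2 = matrix_addition(p3, p5)
    final_quad3 = matrix_addition(p2, p4)
    final_quad4 = matrix_addition(matrix_subtraction(p1, p2), matrix_addition(p3, p6))
    return combine_submatrices(final_quad1, final_quad2, final_quad3, final_quad4)


def main():
    # Example matrices are local, so strassen can only see its own parameters
    A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    B = [[17, 18, 19, 20], [21, 22, 23, 24], [25, 26, 27, 28], [29, 30, 31, 32]]
    for row in strassen(A, B):
        print(row)


if __name__ == "__main__":
    main()
```

Running it prints the four rows of C given in the question.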

huangapple
  • Posted on July 23, 2023 at 19:56:58
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76748105.html