23
23
24
24
import numpy as np
25
25
26
- from papyrus .utils .matrix_utils import compute_hermitian_eigensystem
26
+ from papyrus .utils .matrix_utils import (
27
+ compute_hermitian_eigensystem ,
28
+ normalize_gram_matrix ,
29
+ )
27
30
28
31
29
32
def compute_trace (matrix : np .ndarray , normalize : bool = False ) -> np .ndarray :
@@ -77,7 +80,10 @@ def compute_shannon_entropy(dist: np.ndarray, effective: bool = False) -> float:
77
80
78
81
79
82
def compute_von_neumann_entropy (
80
- matrix : np .ndarray , effective : bool = True , normalize_eig : bool = True
83
+ matrix : np .ndarray ,
84
+ effective : bool = True ,
85
+ normalize_eig : bool = True ,
86
+ normalize_matrix : bool = False ,
81
87
) -> float :
82
88
"""
83
89
Compute the von-Neumann entropy of a matrix.
@@ -91,12 +97,19 @@ def compute_von_neumann_entropy(
91
97
the system thereby returning the effective entropy.
92
98
normalize_eig : bool (default = True)
93
99
If true, the eigenvalues are scaled to look like probabilities.
100
+ normalize_matrix : bool (default=False)
101
+ If true, the NTK is normalized by the square root of the product of the
102
+ corresponding diagonal elements. This is equivalent to normalizing the
103
+ gradient vectors forming the NTK.
94
104
95
105
Returns
96
106
-------
97
107
entropy : float
98
108
Von-Neumann entropy of the matrix.
99
109
"""
110
+ if normalize_matrix :
111
+ matrix = normalize_gram_matrix (matrix )
112
+
100
113
eigvals , _ = compute_hermitian_eigensystem (matrix , normalize = normalize_eig )
101
114
102
115
entropy = compute_shannon_entropy (eigvals )
0 commit comments