-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Marc record tensornetwork entropy #94
Open
SamTov
wants to merge
28
commits into
main
Choose a base branch
from
Marc_record_tensornetwork_entropy
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
d253449
changed loss derivative to be a vector
5e3f8f5
Implemented the recording of the fisher trace
928fbc7
Implemented the recording of the fisher trace
ffbacbf
fixed a commit issue after ammending to a already pushed branch
1f6888d
fixing black formatting
1b386ab
somehow i did every change twice, this commit removed the doubles
e553e68
changes
SamTov 1984382
optimized fisher trace calculation
fac41d8
black formatting again
f3c64b5
and again
49a33e0
Move fisher computation to own module.
SamTov 14a8e7f
added test for the fisher_trace calculation
87ea89d
changes for isort
1fad2f2
black formatter changes
83a7289
refining the fisher trace test
68c6ab7
Added some info to a warning
08318ab
Implemented the tensornetwork entropy recording
eb44f81
added different entropy computations for the tensornetwork matrix
40cba18
forgot to stage a change (:
5428639
missspelled correct matrix for the tensornetwork matrix test
bc9d6bc
optimized the tensornetwork matrix algorithm
72a6e69
Take mean over output dimension ntks instead of just the trace
2c28b8f
added docstring
2e69aaa
fixed some testing
f266767
had to change collective variables example because of changes to the
dd49f8b
implemented tracing for all the observables that need a 2x2 ntk
2d602c0
addition to the commit of changing the notebook
f566a2b
numpy is called onp
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
59 changes: 59 additions & 0 deletions
59
CI/unit_tests/observables/test_fisher_trace_calculation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
""" | ||
ZnNL: A Zincwarecode package. | ||
|
||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
|
||
SPDX-License-Identifier: EPL-2.0 | ||
|
||
Copyright Contributors to the Zincwarecode Project. | ||
|
||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
|
||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
|
||
Summary | ||
------- | ||
This module tests the implementation of the fisher trace computation module. | ||
""" | ||
|
||
import numpy as np | ||
|
||
from znnl.observables.fisher_trace import compute_fisher_trace | ||
|
||
|
||
class TestFisherTrace: | ||
""" | ||
Class for testing the implementation of the fisher trace calculation | ||
""" | ||
|
||
def test_fisher_trace_computation(self): | ||
""" | ||
Function tests if the fisher trace computation works correctly for an | ||
example which was calculated by hand before. | ||
|
||
Returns | ||
------- | ||
Asserts the calculated fisher trace for the manually defined inputs | ||
is what it should be. | ||
""" | ||
|
||
ntk = np.array( | ||
[ | ||
[[[1, 2, 3], [4, 5, 6], [7, 8, 9]], np.random.rand(3, 3)], | ||
[np.random.rand(3, 3), [[2, 1, 3], [1, 2, 3], [3, 2, 1]]], | ||
] | ||
) | ||
loss_derivative = np.array([[5, 4, 3], [2, 1, 0]]) | ||
|
||
trace = compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk) | ||
assert trace == 638 / 2 | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
""" | ||
ZnNL: A Zincwarecode package. | ||
|
||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
|
||
SPDX-License-Identifier: EPL-2.0 | ||
|
||
Copyright Contributors to the Zincwarecode Project. | ||
|
||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
|
||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
|
||
Summary | ||
------- | ||
This module tests the implementation of the tensornetwork matrix computation module. | ||
""" | ||
|
||
import numpy as np | ||
from numpy.testing import assert_almost_equal | ||
|
||
from znnl.observables.tensornetwork_matrix import compute_tensornetwork_matrix | ||
|
||
|
||
class TestTensornetworkMatrix: | ||
""" | ||
Class for testing the implementation of the tensornetwork matrix calculation | ||
""" | ||
|
||
def test_tensornetwork_matrix_computation(self): | ||
""" | ||
Function tests if the fisher trace computation works correctly for an | ||
example which was calculated by hand before. | ||
|
||
Returns | ||
------- | ||
Asserts the calculated fisher trace for the manually defined inputs | ||
is what it should be. | ||
""" | ||
|
||
ntk = np.array([[1, 2, 3, 4], [1, 1, 1, 1], [0, 1, 2, 4], [1, 0, 0, 8]]) | ||
targets = np.array([2, 2, 1, 1]) | ||
|
||
matrix = compute_tensornetwork_matrix(ntk=ntk, targets=targets) | ||
correctmatrix = [[3.5, 0.5], [2.25, 1.25]] | ||
|
||
assert_almost_equal(matrix, correctmatrix) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are all of these notebooks cleared of their output cells? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
""" | ||
ZnNL: A Zincwarecode package. | ||
|
||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
|
||
SPDX-License-Identifier: EPL-2.0 | ||
|
||
Copyright Contributors to the Zincwarecode Project. | ||
|
||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
|
||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
|
||
Summary | ||
------- | ||
Module for the observables. | ||
""" | ||
from znnl.observables.covariance_entropy import compute_covariance_entropy | ||
from znnl.observables.entropy import compute_entropy | ||
from znnl.observables.fisher_trace import compute_fisher_trace | ||
from znnl.observables.magnitude_entropy import compute_magnitude_density | ||
from znnl.observables.tensornetwork_matrix import compute_tensornetwork_matrix | ||
|
||
__all__ = [ | ||
compute_fisher_trace.__name__, | ||
compute_tensornetwork_matrix.__name__, | ||
compute_entropy.__name__, | ||
compute_magnitude_density.__name__, | ||
compute_covariance_entropy.__name__, | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
ZnNL: A Zincwarecode package. | ||
|
||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
|
||
SPDX-License-Identifier: EPL-2.0 | ||
|
||
Copyright Contributors to the Zincwarecode Project. | ||
|
||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
|
||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
|
||
Summary | ||
------- | ||
Module for the computation of the matrix covariance entropy. | ||
""" | ||
|
||
import numpy as np | ||
|
||
from znnl.analysis.entropy import EntropyAnalysis | ||
from znnl.utils.matrix_utils import normalize_gram_matrix | ||
|
||
|
||
def compute_covariance_entropy(matrix: np.ndarray): | ||
""" | ||
Function to compute the covariance entropy of a matrix. | ||
|
||
Parameters | ||
---------- | ||
Matrix of which to compute the covariance entropy. | ||
|
||
Returns | ||
------- | ||
Covariance entropy of the matrix.""" | ||
|
||
cov_matrix = normalize_gram_matrix(matrix) | ||
calculator = EntropyAnalysis(matrix=cov_matrix) | ||
entropy = calculator.compute_von_neumann_entropy( | ||
effective=False, normalize_eig=True | ||
) | ||
|
||
return entropy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
""" | ||
ZnNL: A Zincwarecode package. | ||
|
||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
|
||
SPDX-License-Identifier: EPL-2.0 | ||
|
||
Copyright Contributors to the Zincwarecode Project. | ||
|
||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
|
||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
|
||
Summary | ||
------- | ||
Module for the computation of the matrix entropy. | ||
""" | ||
|
||
import numpy as np | ||
|
||
from znnl.analysis.entropy import EntropyAnalysis | ||
|
||
|
||
def compute_entropy(matrix: np.ndarray): | ||
""" | ||
Function to compute the entropy of a matrix. | ||
|
||
Parameters | ||
---------- | ||
Matrix of which to compute the entropy. | ||
|
||
Returns | ||
------- | ||
Entropy of the matrix.""" | ||
calculator = EntropyAnalysis(matrix=matrix) | ||
entropy = calculator.compute_von_neumann_entropy( | ||
effective=False, normalize_eig=True | ||
) | ||
|
||
return entropy |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the 1/2 coming from? Is this just a simple solution to the analytic trace or are you normalising?