Marc record tensornetwork entropy #94

Open · wants to merge 28 commits into main (base branch)

Commits (28)
d253449
changed loss derivative to be a vector
May 22, 2023
5e3f8f5
Implemented the recording of the fisher trace
May 23, 2023
928fbc7
Implemented the recording of the fisher trace
May 23, 2023
ffbacbf
fixed a commit issue after amending to an already pushed branch
May 23, 2023
1f6888d
fixing black formatting
May 23, 2023
1b386ab
somehow I did every change twice; this commit removes the duplicates
May 23, 2023
e553e68
changes
SamTov May 23, 2023
1984382
optimized fisher trace calculation
May 23, 2023
fac41d8
black formatting again
May 23, 2023
f3c64b5
and again
May 23, 2023
49a33e0
Move fisher computation to own module.
SamTov May 23, 2023
14a8e7f
added test for the fisher_trace calculation
May 23, 2023
87ea89d
changes for isort
May 23, 2023
1fad2f2
black formatter changes
May 23, 2023
83a7289
refining the fisher trace test
May 23, 2023
68c6ab7
Added some info to a warning
May 23, 2023
08318ab
Implemented the tensornetwork entropy recording
May 23, 2023
eb44f81
added different entropy computations for the tensornetwork matrix
May 24, 2023
40cba18
forgot to stage a change (:
May 24, 2023
5428639
misspelled the correct matrix for the tensornetwork matrix test
May 24, 2023
bc9d6bc
optimized the tensornetwork matrix algorithm
May 24, 2023
72a6e69
Take mean over output dimension ntks instead of just the trace
May 24, 2023
2c28b8f
added docstring
May 24, 2023
2e69aaa
fixed some testing
May 24, 2023
f266767
had to change collective variables example because of changes to the
May 25, 2023
dd49f8b
implemented tracing for all the observables that need a 2x2 ntk
May 25, 2023
2d602c0
addition to the commit of changing the notebook
May 25, 2023
f566a2b
numpy is called onp
May 25, 2023
59 changes: 59 additions & 0 deletions CI/unit_tests/observables/test_fisher_trace_calculation.py
@@ -0,0 +1,59 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
This module tests the implementation of the fisher trace computation module.
"""

import numpy as np

from znnl.observables.fisher_trace import compute_fisher_trace


class TestFisherTrace:
    """
    Class for testing the implementation of the fisher trace calculation.
    """

    def test_fisher_trace_computation(self):
        """
        Test that the fisher trace computation works correctly for an
        example which was calculated by hand beforehand.

        Returns
        -------
        Asserts that the calculated fisher trace for the manually defined
        inputs is what it should be.
        """

        ntk = np.array(
            [
                [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], np.random.rand(3, 3)],
                [np.random.rand(3, 3), [[2, 1, 3], [1, 2, 3], [3, 2, 1]]],
            ]
        )
        loss_derivative = np.array([[5, 4, 3], [2, 1, 0]])

        trace = compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk)
        assert trace == 638 / 2
Member Author commented:
Where is the 1/2 coming from? Is this just a simple solution to the analytic trace or are you normalising?
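
For reference, the expected value is consistent with summing the per-sample quadratic form d_i^T K_ii d_i over the two data points (624 + 14 = 638) and then dividing by the number of samples, i.e. the 1/2 reads as a mean over the dataset rather than part of the analytic trace. A minimal numpy sketch of that reading follows; the helper name and the averaging convention are assumptions inferred from the test values, not the module's actual implementation.

import numpy as np


def fisher_trace_by_hand(loss_derivative: np.ndarray, ntk: np.ndarray) -> float:
    """Assumed reading of the test value: mean over samples of d_i^T K_ii d_i."""
    n_samples = loss_derivative.shape[0]
    per_sample = [
        loss_derivative[i] @ ntk[i, i] @ loss_derivative[i] for i in range(n_samples)
    ]
    return float(np.sum(per_sample) / n_samples)


# Same diagonal blocks as the test; the off-diagonal blocks never enter this sum,
# which is why the test can fill them with random values and still assert exactly.
ntk = np.array(
    [
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], np.zeros((3, 3))],
        [np.zeros((3, 3)), [[2, 1, 3], [1, 2, 3], [3, 2, 1]]],
    ]
)
loss_derivative = np.array([[5, 4, 3], [2, 1, 0]])
print(fisher_trace_by_hand(loss_derivative, ntk))  # 319.0 == 638 / 2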

57 changes: 57 additions & 0 deletions CI/unit_tests/observables/test_tensornetwork_matrix.py
@@ -0,0 +1,57 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
This module tests the implementation of the tensornetwork matrix computation module.
"""

import numpy as np
from numpy.testing import assert_almost_equal

from znnl.observables.tensornetwork_matrix import compute_tensornetwork_matrix


class TestTensornetworkMatrix:
    """
    Class for testing the implementation of the tensornetwork matrix calculation.
    """

    def test_tensornetwork_matrix_computation(self):
        """
        Test that the tensornetwork matrix computation works correctly for an
        example which was calculated by hand beforehand.

        Returns
        -------
        Asserts that the calculated tensornetwork matrix for the manually
        defined inputs is what it should be.
        """

        ntk = np.array([[1, 2, 3, 4], [1, 1, 1, 1], [0, 1, 2, 4], [1, 0, 0, 8]])
        targets = np.array([2, 2, 1, 1])

        matrix = compute_tensornetwork_matrix(ntk=ntk, targets=targets)
        correct_matrix = [[3.5, 0.5], [2.25, 1.25]]

        assert_almost_equal(matrix, correct_matrix)
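
The expected 2x2 matrix is consistent with grouping the four samples by their target label (unique labels sorted ascending, so label 1 indexes the first row and column) and averaging each resulting block of the NTK. The sketch below reproduces correct_matrix under that assumption; the block-averaging reading and the helper name are inferred from the test values, not taken from the module itself.

import numpy as np


def block_averaged_ntk(ntk: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Average NTK entries over each pair of target classes (assumed construction)."""
    labels = np.unique(targets)  # sorted unique labels: [1, 2]
    out = np.zeros((labels.size, labels.size))
    for i, a in enumerate(labels):
        for j, b in enumerate(labels):
            # Sub-block of the NTK restricted to samples of class a (rows)
            # and class b (columns), reduced to its mean.
            out[i, j] = ntk[np.ix_(targets == a, targets == b)].mean()
    return out


ntk = np.array([[1, 2, 3, 4], [1, 1, 1, 1], [0, 1, 2, 4], [1, 0, 0, 8]])
targets = np.array([2, 2, 1, 1])
print(block_averaged_ntk(ntk, targets))  # [[3.5, 0.5], [2.25, 1.25]]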
4 changes: 4 additions & 0 deletions CI/unit_tests/training_recording/test_training_recording.py
@@ -64,6 +64,10 @@ def test_instantiation(self):
            eigenvalues=True,
            trace=True,
            loss_derivative=True,
            fisher_trace=True,
            tensornetwork_entropy=True,
            tensornetwork_covariance_entropy=True,
            tensornetwork_magnitude_entropy=True,
        )
        recorder.instantiate_recorder(data_set=self.dummy_data_set)
        _exclude_list = [
12 changes: 11 additions & 1 deletion examples/Computing-Collective-Variables.ipynb
Member Author commented:
Are all of these notebooks cleared of their output cells?

@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "1155f508",
"metadata": {},
@@ -11,6 +12,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c2f5d943",
"metadata": {},
@@ -33,6 +35,7 @@
"os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n",
"\n",
"import znnl as nl\n",
"from znnl.utils.matrix_utils import calculate_l_pq_norm\n",
"from neural_tangents import stax\n",
"import optax\n",
"\n",
@@ -45,6 +48,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "9d379c99",
"metadata": {},
@@ -111,6 +115,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "38fa1b9d",
"metadata": {
@@ -154,6 +159,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a6a9fbe0-def2-4bab-a808-8858ab2aa5e9",
"metadata": {
@@ -207,6 +213,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "612896b2-e9e2-4917-b01b-1ff56da2b142",
"metadata": {
@@ -233,6 +240,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "38bda516-9feb-4ab1-b3d3-3a552b46e28c",
"metadata": {
@@ -259,6 +267,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "99646c3f-95ed-4c4e-b01e-5646b3767ff4",
"metadata": {},
@@ -331,8 +340,9 @@
"metadata": {},
"outputs": [],
"source": [
"loss_derivative = calculate_l_pq_norm(train_report.loss_derivative)\n",
"plt.plot(\n",
" train_report.loss_derivative, \n",
" loss_derivative, \n",
" 'o', \n",
" mfc='None', \n",
" label=r\"$|| \\frac{\\partial L }{ \\partial f}||_R$\"\n",
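
Context for the last hunk: train_report.loss_derivative is now a vector per recorded epoch (see the first commit in this PR), so the notebook reduces it to a scalar with calculate_l_pq_norm before plotting. If that helper computes the standard L_{p,q} matrix norm, a plain-numpy equivalent for a single matrix would look roughly like the sketch below; the default p = q = 2 (the Frobenius norm) and the exact signature are assumptions, and the real helper in znnl.utils.matrix_utils may additionally vectorise over the recorded steps.

import numpy as np


def l_pq_norm(matrix: np.ndarray, p: float = 2.0, q: float = 2.0) -> float:
    """L_{p,q} norm: p-norm down each column, then q-norm across the columns."""
    column_sums_p = np.sum(np.abs(matrix) ** p, axis=0)  # sum_i |a_ij|^p per column
    return float(np.sum(column_sums_p ** (q / p)) ** (1.0 / q))


# For p = q = 2 this reduces to the Frobenius norm.
example = np.array([[3.0, 0.0], [4.0, 0.0]])
print(l_pq_norm(example))  # 5.0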
40 changes: 40 additions & 0 deletions znnl/observables/__init__.py
@@ -0,0 +1,40 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the observables.
"""
from znnl.observables.covariance_entropy import compute_covariance_entropy
from znnl.observables.entropy import compute_entropy
from znnl.observables.fisher_trace import compute_fisher_trace
from znnl.observables.magnitude_entropy import compute_magnitude_density
from znnl.observables.tensornetwork_matrix import compute_tensornetwork_matrix

__all__ = [
    compute_fisher_trace.__name__,
    compute_tensornetwork_matrix.__name__,
    compute_entropy.__name__,
    compute_magnitude_density.__name__,
    compute_covariance_entropy.__name__,
]
53 changes: 53 additions & 0 deletions znnl/observables/covariance_entropy.py
@@ -0,0 +1,53 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the computation of the matrix covariance entropy.
"""

import numpy as np

from znnl.analysis.entropy import EntropyAnalysis
from znnl.utils.matrix_utils import normalize_gram_matrix


def compute_covariance_entropy(matrix: np.ndarray):
    """
    Compute the covariance entropy of a matrix.

    Parameters
    ----------
    matrix : np.ndarray
            Matrix of which to compute the covariance entropy.

    Returns
    -------
    Covariance entropy of the matrix.
    """
    cov_matrix = normalize_gram_matrix(matrix)
    calculator = EntropyAnalysis(matrix=cov_matrix)
    entropy = calculator.compute_von_neumann_entropy(
        effective=False, normalize_eig=True
    )

    return entropy
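
The only difference from the plain entropy module below is the normalize_gram_matrix call. If that utility implements the usual cosine-style normalisation K_ij / sqrt(K_ii K_jj), so that the gram matrix has a unit diagonal before the von Neumann entropy is taken, the step amounts to the sketch below; the exact behaviour of normalize_gram_matrix should be checked in znnl.utils.matrix_utils, so treat this as an assumption.

import numpy as np


def cosine_normalise(gram: np.ndarray) -> np.ndarray:
    """Assumed behaviour: rescale K_ij by sqrt(K_ii * K_jj) so the diagonal is one."""
    root_diag = np.sqrt(np.diag(gram))
    return gram / np.outer(root_diag, root_diag)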
50 changes: 50 additions & 0 deletions znnl/observables/entropy.py
@@ -0,0 +1,50 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the computation of the matrix entropy.
"""

import numpy as np

from znnl.analysis.entropy import EntropyAnalysis


def compute_entropy(matrix: np.ndarray):
    """
    Compute the entropy of a matrix.

    Parameters
    ----------
    matrix : np.ndarray
            Matrix of which to compute the entropy.

    Returns
    -------
    Entropy of the matrix.
    """
    calculator = EntropyAnalysis(matrix=matrix)
    entropy = calculator.compute_von_neumann_entropy(
        effective=False, normalize_eig=True
    )

    return entropy
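
A short usage sketch for the two new helpers, assuming the package layout added in this PR (znnl.observables re-exports both functions via its __init__); the gram matrix here is illustrative and not taken from the PR.

import numpy as np

from znnl.observables import compute_covariance_entropy, compute_entropy

# Illustrative positive semi-definite gram matrix built from random features.
features = np.random.rand(10, 4)
gram = features @ features.T

print(compute_entropy(gram))             # von Neumann entropy of the raw matrix
print(compute_covariance_entropy(gram))  # same, after normalising the gram matrix first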