Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 7ef9d8e

Browse files
authored
Add FID (#40)
* [WIP] Add FID * [WIP] Add FID * Add reference copyright information * Refactor FID metric * Add medicalnet feature extractor * Remove feature extractors from implementation * Add tests Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
1 parent 4a72ccc commit 7ef9d8e

File tree

3 files changed

+181
-0
lines changed

3 files changed

+181
-0
lines changed

generative/metrics/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from .fid import FID
1213
from .mmd import MMD
1314
from .ms_ssim import MSSSIM

generative/metrics/fid.py

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
# =========================================================================
13+
# Adapted from https://github.com/photosynthesis-team/piq
14+
# which has the following license:
15+
# https://github.com/photosynthesis-team/piq/blob/master/LICENSE
16+
17+
# Copyright 2023 photosynthesis-team. All rights reserved.
18+
#
19+
# Licensed under the Apache License, Version 2.0 (the "License");
20+
# you may not use this file except in compliance with the License.
21+
# You may obtain a copy of the License at
22+
#
23+
# http://www.apache.org/licenses/LICENSE-2.0
24+
#
25+
# Unless required by applicable law or agreed to in writing, software
26+
# distributed under the License is distributed on an "AS IS" BASIS,
27+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28+
# See the License for the specific language governing permissions and
29+
# limitations under the License.
30+
# =========================================================================
31+
32+
from __future__ import annotations
33+
34+
import torch
35+
from monai.metrics.metric import Metric
36+
37+
38+
class FID(Metric):
    """
    Frechet Inception Distance (FID): a distance between two distributions of feature vectors.

    Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium."
    https://arxiv.org/abs/1706.08500

    The metric takes two groups of feature vectors, each with format (number images, number of features), that
    were extracted beforehand with a pretrained network; no feature extractor is run here.

    The original proposal uses the activations of the pool_3 layer of an Inception v3 pretrained on ImageNet.
    Other networks pretrained on medical datasets can be used as well (for example, RadImageNet for 2D and
    MedicalNet for 3D images). If the chosen layer's output still has spatial dimensions, a global spatial
    average pooling is usually applied to obtain one value per feature.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, y_pred: torch.Tensor, y: torch.Tensor):
        """
        Compute the FID between two groups of feature vectors.

        Args:
            y_pred: feature vectors of the first group, passed through to `get_fid_score`.
            y: feature vectors of the second group, passed through to `get_fid_score`.

        Returns:
            Whatever `get_fid_score` returns for the two inputs.
        """
        fid_value = get_fid_score(y_pred, y)
        return fid_value
56+
57+
58+
def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute the Frechet distance between two groups of feature vectors.

    Both inputs are cast to float, their per-feature mean (over dim 0) and covariance are estimated, and the
    resulting statistics are handed to `compute_frechet_distance`.

    Args:
        y_pred: feature vectors of the first group, with format (number images, number of features).
        y: feature vectors of the second group, with format (number images, number of features).

    Returns:
        The value returned by `compute_frechet_distance` for the statistics of the two groups.

    Raises:
        ValueError: if either `y_pred` or `y` has more than two dimensions.
    """
    y = y.float()
    y_pred = y_pred.float()

    # Check both inputs: an image-shaped `y_pred` would otherwise slip through and only fail later,
    # with an unrelated error, while its covariance is being estimated.
    if y.ndimension() > 2 or y_pred.ndimension() > 2:
        raise ValueError("Inputs should have (number images, number of features) shape.")

    mu_y_pred = torch.mean(y_pred, dim=0)
    sigma_y_pred = _cov(y_pred, rowvar=False)
    mu_y = torch.mean(y, dim=0)
    sigma_y = _cov(y, rowvar=False)

    return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y)
71+
72+
73+
def _cov(m: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
    """
    Estimate the covariance matrix of a set of variables.

    Args:
        m: A 1-D or 2-D tensor holding variables and observations. A tensor with fewer than two dimensions is
            viewed as a single row, i.e. one variable.
        rowvar: If True (default), every row is a variable and every column an observation. If False, the
            layout is transposed (columns are variables, rows are observations); a tensor whose first
            dimension has size 1 is left untransposed.

    Returns:
        The covariance estimate normalised by (number of observations - 1), with singleton dimensions of the
        matrix product squeezed away (so a single variable yields a 0-dim tensor).
    """
    samples = m.view(1, -1) if m.dim() < 2 else m

    # Bring the data into "rows are variables" layout.
    if samples.size(0) != 1 and not rowvar:
        samples = samples.t()

    scale = 1.0 / (samples.size(1) - 1)
    centered = samples - samples.mean(dim=1, keepdim=True)
    return scale * centered.matmul(centered.t()).squeeze()
94+
95+
96+
def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Approximate the square root of a matrix with the Newton-Schulz iterative method. Based on:
    https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py. Benchmark shown in:
    https://github.com/photosynthesis-team/piq/issues/190#issuecomment-742039303

    Args:
        matrix: square 2-D matrix (the iteration relies on `mm`, which only accepts 2-D tensors).
        num_iters: maximum number of iterations of the method.

    Returns:
        A tuple `(sqrt_matrix, error)` where `error` is the norm of `matrix - sqrt_matrix @ sqrt_matrix`
        divided by the norm of `matrix`. If `num_iters` is 0 the loop never runs and both outputs are the
        uninitialised buffers allocated up front.
    """
    size = matrix.size(0)
    scale = matrix.norm(p="fro")
    identity = torch.eye(size, size, device=matrix.device, dtype=matrix.dtype)

    # The input is normalised by its Frobenius norm before iterating.
    y_iter = matrix.div(scale)
    z_iter = identity.clone()

    root = torch.empty_like(matrix)
    residual = torch.empty(1, device=matrix.device, dtype=matrix.dtype)

    for _ in range(num_iters):
        step = 0.5 * (3.0 * identity - z_iter.mm(y_iter))
        y_iter = y_iter.mm(step)
        z_iter = step.mm(z_iter)

        # Undo the initial normalisation to get the current estimate of the square root.
        root = y_iter * torch.sqrt(scale)

        scale = torch.norm(matrix)
        residual = torch.norm(matrix - torch.mm(root, root)) / scale

        # Stop early once the relative reconstruction error is numerically zero.
        zero = torch.tensor([0.0], device=residual.device, dtype=residual.dtype)
        if torch.isclose(residual, zero, atol=1e-5):
            break

    return root, residual
131+
132+
133+
def compute_frechet_distance(
    mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6
) -> torch.Tensor:
    """
    The Frechet distance between multivariate normal distributions.

    Evaluates `|mu_x - mu_y|^2 + Tr(sigma_x) + Tr(sigma_y) - 2 * Tr(sqrt(sigma_x @ sigma_y))`, with the matrix
    square root approximated by `_sqrtm_newton_schulz`.

    Args:
        mu_x: mean vector of the first distribution.
        sigma_x: covariance matrix of the first distribution.
        mu_y: mean vector of the second distribution.
        sigma_y: covariance matrix of the second distribution.
        epsilon: value added to the diagonal of both covariances when the first square-root attempt
            contains non-finite entries.

    Returns:
        The Frechet distance as a tensor.
    """
    mean_delta = mu_x - mu_y
    sqrt_product, _ = _sqrtm_newton_schulz(sigma_x.mm(sigma_y))

    # A singular product can make the iteration blow up; retry with a small diagonal offset on both covariances.
    if not torch.isfinite(sqrt_product).all():
        jitter = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon
        sqrt_product, _ = _sqrtm_newton_schulz((sigma_x + jitter).mm(sigma_y + jitter))

    squared_mean_distance = mean_delta.dot(mean_delta)
    trace_sqrt_product = torch.trace(sqrt_product)
    return squared_mean_distance + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * trace_sqrt_product

tests/test_compute_fid_metric.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
13+
import unittest
14+
15+
import numpy as np
16+
import torch
17+
18+
from generative.metrics import FID
19+
20+
21+
class TestMMDMetric(unittest.TestCase):
    """Unit tests for the `FID` metric callable."""

    # NOTE(review): the class name mentions MMD although every test here exercises FID; it looks like a
    # copy-paste leftover. Confirm nothing selects the tests by this name before renaming.

    def test_results(self):
        """FID of two small 2-feature groups matches the expected reference value."""
        features_a = torch.Tensor([[1, 2], [1, 2], [1, 2]])
        features_b = torch.Tensor([[2, 2], [1, 2], [1, 2]])
        metric = FID()
        score = metric(features_a, features_b)
        np.testing.assert_allclose(score.cpu().numpy(), 0.4433, atol=1e-4)

    def test_input_dimensions(self):
        """Image-shaped (4-D) inputs are rejected with a ValueError."""
        images_a = torch.ones([3, 3, 144, 144])
        images_b = torch.ones([3, 3, 145, 145])
        with self.assertRaises(ValueError):
            FID()(images_a, images_b)
31+
32+
33+
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)