Skip to content

Commit 81c6b17

Browse files
xgfstensorflower-gardener
authored andcommitted
Update the SVD metrics interface to support custom getters.
PiperOrigin-RevId: 580626182
1 parent d95d12a commit 81c6b17

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

tensorflow_gnn/models/contrastive_losses/metrics.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Metrics for unsupervised embedding evaluation."""
1616
from __future__ import annotations
1717

18-
from collections.abc import Mapping
18+
from collections.abc import Callable, Mapping
1919
from typing import Optional, Protocol
2020

2121
import tensorflow as tf
@@ -294,6 +294,9 @@ class _SvdMetrics(tf.keras.metrics.Metric):
294294
def __init__(
295295
self,
296296
fns: Mapping[str, _SvdProtocol],
297+
y_pred_transform_fn: Optional[
298+
Callable[[tf.Tensor], tf.Tensor]
299+
] = None,
297300
name: str = "svd_metrics",
298301
):
299302
"""Constructs the `tf.keras.metrics.Metric` that reuses SVD decomposition.
@@ -302,27 +305,30 @@ def __init__(
302305
fns: a mapping from a metric name to a `Callable` that accepts
303306
representations as well as the result of their SVD decomposition.
304307
Currently only singular values are passed.
308+
y_pred_transform_fn: a function to extract clean representations
309+
from model predictions. By default, no transformation is applied.
305310
name: Name for the metric class, used for Keras bookkeeping.
306311
"""
307312
super().__init__(name=name)
308313
self._fns = fns
309314
self._metric_container = {
310315
k: tf.keras.metrics.Mean(name=k) for k in fns.keys()
311316
}
317+
if not y_pred_transform_fn:
318+
y_pred_transform_fn = lambda x: x
319+
self._y_pred_transform_fn = y_pred_transform_fn
312320

313321
def reset_state(self) -> None:
314322
for v in self._metric_container.values():
315323
v.reset_state()
316324

317325
def update_state(self, _, y_pred: tf.Tensor, sample_weight=None) -> None:
318-
# In our implementation of contrastive learning, y_pred is a tensor with
319-
# clean and corrupted representations stacked in the first dimension.
320-
representations_clean, _ = tf.unstack(y_pred, axis=1)
326+
representations = self._y_pred_transform_fn(y_pred)
321327
sigma, u, _ = tf.linalg.svd(
322-
representations_clean, compute_uv=True, full_matrices=False
328+
representations, compute_uv=True, full_matrices=False
323329
)
324330
for k, v in self._metric_container.items():
325-
v.update_state(self._fns[k](representations_clean, sigma=sigma, u=u))
331+
v.update_state(self._fns[k](representations, sigma=sigma, u=u))
326332

327333
def result(self) -> Mapping[str, tf.Tensor]:
328334
return {k: v.result() for k, v in self._metric_container.items()}

tensorflow_gnn/models/contrastive_losses/metrics_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_triplet_metrics(self):
159159
self.assertAllClose(actual["triplet_distance"], -1.75)
160160

161161
def test_svd_metrics(self):
162-
tensor = tf.ones((2, 2, 2))
162+
tensor = tf.ones((2, 2))
163163
metric_object = metrics.AllSvdMetrics()
164164
metric_object.update_state(None, tensor)
165165
result = metric_object.result()

tensorflow_gnn/models/contrastive_losses/tasks.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@
3232
GraphTensor = tfgnn.GraphTensor
3333

3434

35+
# In our implementation of contrastive learning, `y_pred`` is a tensor with
36+
# clean and corrupted representations stacked in the first dimension.
37+
_UNSTACK_FN = lambda y_pred: tf.unstack(y_pred, axis=1)[0]
38+
39+
3540
class ContrastiveLossTask(runner.Task):
3641
"""Base class for unsupervised contrastive representation learning tasks.
3742
@@ -226,7 +231,9 @@ def metrics(self) -> runner.Metrics:
226231
tf.keras.metrics.BinaryCrossentropy(from_logits=True),
227232
tf.keras.metrics.BinaryAccuracy(),
228233
),
229-
"representations": (metrics.AllSvdMetrics(),),
234+
"representations": (
235+
metrics.AllSvdMetrics(y_pred_transform_fn=_UNSTACK_FN),
236+
),
230237
}
231238

232239

@@ -271,7 +278,7 @@ def loss_fn(_, x):
271278
return loss_fn
272279

273280
def metrics(self) -> runner.Metrics:
274-
return (metrics.AllSvdMetrics(),)
281+
return (metrics.AllSvdMetrics(y_pred_transform_fn=_UNSTACK_FN),)
275282

276283

277284
class VicRegTask(ContrastiveLossTask):
@@ -305,7 +312,7 @@ def loss_fn(_, x):
305312
return loss_fn
306313

307314
def metrics(self) -> runner.Metrics:
308-
return (metrics.AllSvdMetrics(),)
315+
return (metrics.AllSvdMetrics(y_pred_transform_fn=_UNSTACK_FN),)
309316

310317

311318
class TripletLossTask(ContrastiveLossTask):

0 commit comments

Comments
 (0)