-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathloss.py
33 lines (26 loc) · 1.19 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
from torch.nn import functional as F
def get_loss_module(task_name):
    """Return the loss module appropriate for the given task.

    Args:
        task_name: one of "imputation", "transduction", "regression".

    Returns:
        An ``nn.Module`` loss with ``reduction='none'`` so the caller
        receives a per-element loss tensor (e.g. for per-sample weighting).

    Raises:
        ValueError: if ``task_name`` is not a recognized task.
    """
    # Single if/elif/else chain: the original used a bare `if` followed by
    # an `if/else`, where the `else` bound only to the second `if` — easy
    # to break when adding tasks.
    if task_name in ("imputation", "transduction"):
        return MaskedMSELoss(reduction='none')  # outputs loss for each batch element
    elif task_name == "regression":
        return nn.MSELoss(reduction='none')  # outputs loss for each batch sample
    else:
        raise ValueError("Loss module for task '{}' does not exist".format(task_name))
def l2_reg_loss(model):
    """Return the squared L2 norm of the model's output layer weights.

    Looks up the parameter registered as ``output_layer.weight``.
    NOTE: implicitly returns ``None`` when the model has no parameter
    with that exact name.
    """
    weight = dict(model.named_parameters()).get('output_layer.weight')
    if weight is None:
        return None
    return torch.sum(torch.square(weight))
class MaskedMSELoss(nn.Module):
    """MSE loss computed only over positions selected by a boolean mask.

    Unmasked (False) positions contribute nothing to the loss; the
    underlying ``nn.MSELoss`` is applied to the selected elements with
    the configured reduction.
    """

    def __init__(self, reduction: str = 'mean'):
        """
        Args:
            reduction: reduction mode forwarded to ``nn.MSELoss``
                ('mean', 'sum', or 'none').
        """
        super().__init__()
        self.reduction = reduction
        self.mse_loss = nn.MSELoss(reduction=self.reduction)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        """Compute MSE between prediction and target at masked positions.

        Args:
            y_pred: predicted values.
            y_true: ground-truth values, same shape as ``y_pred``.
            mask: boolean tensor; True marks elements included in the loss.

        Returns:
            The loss over the selected elements (scalar unless
            ``reduction='none'``).
        """
        # Boolean indexing flattens the selected elements, exactly like
        # torch.masked_select in the original implementation.
        selected_pred = y_pred[mask]
        selected_true = y_true[mask]
        return self.mse_loss(selected_pred, selected_true)