@@ -504,3 +504,52 @@ def compute_loss(self, predictions, y_true):
         # Compute the negative log-likelihood
         nll = -cat_dist.log_prob(y_true).mean()
         return nll
+
+
+class Quantile(BaseDistribution):
+    """
+    Quantile Regression Loss class.
+
+    This class computes the quantile loss (also known as pinball loss) for a set of quantiles.
+    It is used to handle quantile regression tasks where we aim to predict a given quantile of the target distribution.
+
+    Parameters
+    ----------
+    name : str, optional
+        The name of the distribution, by default "Quantile".
+    quantiles : list of float, optional
+        A list of quantiles to be used for computing the loss, by default [0.25, 0.5, 0.75].
+
+    Attributes
+    ----------
+    quantiles : list of float
+        List of quantiles for which the pinball loss is computed.
+
+    Methods
+    -------
+    compute_loss(predictions, y_true)
+        Computes the quantile regression loss between the predictions and true values.
+    """
+
+    def __init__(self, name="Quantile", quantiles=[0.25, 0.5, 0.75]):
+        param_names = [
+            f"q_{q}" for q in quantiles
+        ]  # Use string representations of quantiles
+        super().__init__(name, param_names)
+        self.quantiles = quantiles
+
+    def compute_loss(self, predictions, y_true):
+
+        assert not y_true.requires_grad  # Ensure y_true does not require gradients
+        assert predictions.size(0) == y_true.size(0)  # Ensure batch size matches
+
+        losses = []
+        for i, q in enumerate(self.quantiles):
+            errors = y_true - predictions[:, i]  # Calculate errors for each quantile
+            # Compute the pinball loss
+            quantile_loss = torch.max((q - 1) * errors, q * errors)
+            losses.append(quantile_loss)
+
+        # Sum losses across quantiles and compute mean
+        loss = torch.mean(torch.stack(losses, dim=1).sum(dim=1))
+        return loss
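
For readers skimming the hunk, here is a minimal, self-contained sketch of the pinball-loss math that compute_loss applies, reproduced with plain torch ops on a toy batch. BaseDistribution lives elsewhere in the module and is not used here; the quantile values and tensor shapes below are illustrative only.

# Minimal sketch (illustrative values; uses only the standard torch API):
# reproduces the pinball-loss computation from compute_loss above on a toy
# batch, without going through the Quantile/BaseDistribution classes.
import torch

quantiles = [0.25, 0.5, 0.75]
y_true = torch.tensor([1.0, 2.0, 3.0])   # shape (batch,)
predictions = torch.tensor([             # shape (batch, len(quantiles)), one column per quantile
    [0.8, 1.1, 1.4],
    [1.9, 2.0, 2.2],
    [2.5, 3.1, 3.6],
])

losses = []
for i, q in enumerate(quantiles):
    errors = y_true - predictions[:, i]
    # Pinball loss: q * error when under-predicting, (1 - q) * |error| when over-predicting
    losses.append(torch.max((q - 1) * errors, q * errors))

# Sum over quantiles per sample, then average over the batch (matches compute_loss)
loss = torch.mean(torch.stack(losses, dim=1).sum(dim=1))
print(loss)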