Skip to content

Commit 48d22da

Browse files
authored
Merge pull request #125 from basf/develop
include quantile regression
2 parents 967f49f + ccfc75a commit 48d22da

File tree

4 files changed

+54
-4
lines changed

4 files changed

+54
-4
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ MambularLSS allows you to model the full distribution of a response variable, no
200200
- **negativebinom**: For over-dispersed count data.
201201
- **inversegamma**: Often used as a prior in Bayesian inference.
202202
- **categorical**: For data with more than two categories.
203+
- **quantile**: For quantile regression using the pinball loss.
203204

204205
These distribution classes make MambularLSS versatile in modeling various data types and distributions.
205206

mambular/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Version information."""
22

33
# The following line *must* be the last in the module, exactly as formatted:
4-
__version__ = "0.2.2"
4+
__version__ = "0.2.3"

mambular/models/sklearn_base_lss.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
NormalDistribution,
3131
PoissonDistribution,
3232
StudentTDistribution,
33+
Quantile,
3334
)
3435
from lightning.pytorch.callbacks import ModelSummary
3536

@@ -210,11 +211,9 @@ def build_model(
210211
X, y, X_val, y_val, val_size=val_size, random_state=random_state
211212
)
212213

213-
num_classes = len(np.unique(y))
214-
215214
self.task_model = TaskModel(
216215
model_class=self.base_model,
217-
num_classes=num_classes,
216+
num_classes=self.family.param_count,
218217
config=self.config,
219218
cat_feature_info=self.data_module.cat_feature_info,
220219
num_feature_info=self.data_module.num_feature_info,
@@ -347,6 +346,7 @@ def fit(
347346
"negativebinom": NegativeBinomialDistribution,
348347
"inversegamma": InverseGammaDistribution,
349348
"categorical": CategoricalDistribution,
349+
"quantile": Quantile,
350350
}
351351

352352
if distributional_kwargs is None:

mambular/utils/distributions.py

+49
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,52 @@ def compute_loss(self, predictions, y_true):
504504
# Compute the negative log-likelihood
505505
nll = -cat_dist.log_prob(y_true).mean()
506506
return nll
507+
508+
509+
class Quantile(BaseDistribution):
    """
    Quantile Regression Loss class.

    This class computes the quantile loss (also known as pinball loss) for a
    set of quantiles. It is used to handle quantile regression tasks where we
    aim to predict given quantiles of the target distribution. One output
    "parameter" is produced per quantile, so ``param_count`` equals
    ``len(quantiles)``.

    Parameters
    ----------
    name : str, optional
        The name of the distribution, by default "Quantile".
    quantiles : sequence of float, optional
        The quantiles for which the pinball loss is computed,
        by default (0.25, 0.5, 0.75). Each value must lie in (0, 1).

    Attributes
    ----------
    quantiles : list of float
        List of quantiles for which the pinball loss is computed.

    Methods
    -------
    compute_loss(predictions, y_true)
        Computes the quantile regression loss between the predictions and
        true values.
    """

    # NOTE: the default is an immutable tuple (not a list) so it cannot be
    # mutated and shared across instances — classic mutable-default pitfall.
    def __init__(self, name="Quantile", quantiles=(0.25, 0.5, 0.75)):
        # Use string representations of the quantiles as parameter names,
        # e.g. "q_0.25"; the base class derives param_count from this list.
        param_names = [f"q_{q}" for q in quantiles]
        super().__init__(name, param_names)
        # Store a fresh list so external mutation of the argument cannot
        # affect this instance.
        self.quantiles = list(quantiles)

    def compute_loss(self, predictions, y_true):
        """
        Compute the pinball (quantile) loss summed over quantiles.

        Parameters
        ----------
        predictions : torch.Tensor
            Tensor of shape (batch, len(quantiles)); column i holds the
            prediction for ``self.quantiles[i]``.
        y_true : torch.Tensor
            Target tensor with the same batch size as ``predictions``.
            Must not require gradients.

        Returns
        -------
        torch.Tensor
            Scalar loss: per-sample pinball losses summed over quantiles,
            then averaged over the batch.

        Raises
        ------
        ValueError
            If ``y_true`` requires gradients or the batch sizes differ.
        """
        # Explicit raises instead of asserts: asserts are stripped under
        # `python -O`, silently disabling these checks.
        if y_true.requires_grad:
            raise ValueError("y_true must not require gradients.")
        if predictions.size(0) != y_true.size(0):
            raise ValueError(
                "Batch size mismatch between predictions and y_true: "
                f"{predictions.size(0)} vs {y_true.size(0)}."
            )

        losses = []
        for i, q in enumerate(self.quantiles):
            # Error for this quantile's prediction column.
            errors = y_true - predictions[:, i]
            # Pinball loss: q * error if under-predicting, (1 - q) * |error|
            # if over-predicting; torch.max is elementwise for two tensors.
            quantile_loss = torch.max((q - 1) * errors, q * errors)
            losses.append(quantile_loss)

        # Sum losses across quantiles, then average over the batch.
        loss = torch.mean(torch.stack(losses, dim=1).sum(dim=1))
        return loss

0 commit comments

Comments
 (0)