Log regression task metrics in multitask model #3648

Open
wants to merge 7 commits into base: master

Conversation

@ntravis22 (Contributor) commented Mar 24, 2025

Per-task metrics were recently added to multitask_model; however, none were included for regression tasks, and we did not check that the metric keys are present, which can throw an error. This PR addresses both concerns.
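
As a rough illustration of the key check (the helper and names here are hypothetical, not the code in this PR), the idea is to collect only the metric keys a task's evaluation actually reports, so a missing key is skipped instead of raising a KeyError:

    # Hypothetical sketch, not Flair internals: keep only the metrics this task reported.
    def collect_present_metrics(task_id, task_scores, expected_keys):
        return {(task_id, key): task_scores[key] for key in expected_keys if key in task_scores}

    # A task that reports accuracy but no regression metrics simply yields fewer entries:
    print(collect_present_metrics("Task_0", {"accuracy": 0.9}, ["accuracy", "mse"]))
    # {('Task_0', 'accuracy'): 0.9}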

@MattGPT-ai (Contributor) left a comment

Are there metrics for regression-type models that we could put into scores, or do those perhaps go into the "classification report", e.g. Pearson, Spearman? In the regression models, results is written as:

            eval_metrics = {
                "loss": eval_loss.item(),
                "mse": metric.mean_squared_error(),
                "mae": metric.mean_absolute_error(),
                "pearson": metric.pearsonr(),
                "spearman": metric.spearmanr(),
            }

So maybe we could either check for the base model class that defines evaluate, or just check for the keys. Then maybe we could write e.g. scores[(task_id, 'mse')]. What do you think?
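
A rough sketch of that suggestion, reusing the keys from the eval_metrics dict above (the helper and variable names are illustrative assumptions, not the exact code in multitask_model):

    # Copy whichever of the four regression metrics a sub-task reported into the
    # shared scores dict, keyed by (task_id, metric_name); classification tasks
    # simply won't have these keys, so nothing is added for them.
    REGRESSION_METRICS = ("mse", "mae", "pearson", "spearman")

    def add_regression_scores(scores, task_id, task_scores):
        for name in REGRESSION_METRICS:
            if name in task_scores:
                scores[(task_id, name)] = task_scores[name]

    # Example with arbitrary values in the shape of eval_metrics above:
    scores = {}
    add_regression_scores(scores, "Task_0", {"loss": 0.12, "mse": 0.03, "mae": 0.10, "pearson": 0.91, "spearman": 0.88})
    # scores -> {('Task_0', 'mse'): 0.03, ('Task_0', 'mae'): 0.1, ('Task_0', 'pearson'): 0.91, ('Task_0', 'spearman'): 0.88}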

@ntravis22 changed the title from "Prevent error if not all metrics are present for a task" to "Log mse for regression tasks (and make sure metrics are present)" on Mar 25, 2025
@ntravis22 (Contributor, Author) commented:

> Are there metrics for regression-type models that we could put into scores, or do those perhaps go into the "classification report", e.g. Pearson, Spearman? In the regression models, results is written as:
>
>             eval_metrics = {
>                 "loss": eval_loss.item(),
>                 "mse": metric.mean_squared_error(),
>                 "mae": metric.mean_absolute_error(),
>                 "pearson": metric.pearsonr(),
>                 "spearman": metric.spearmanr(),
>             }
>
> So maybe we could either check for the base model class that defines evaluate, or just check for the keys. Then maybe we could write e.g. scores[(task_id, 'mse')]. What do you think?

Ok I added mse.

@MattGPT-ai (Contributor) commented:

Oh actually, can we just add all four of the metrics?

@ntravis22 changed the title from "Log mse for regression tasks (and make sure metrics are present)" to "Log regression task metrics in multitask model" on Mar 25, 2025
@ntravis22 (Contributor, Author) commented:

> Oh actually, can we just add all four of the metrics?

Done

@alanakbik (Collaborator) commented:

@ntravis22 @MattGPT-ai Could you paste a script to test this PR?

@ntravis22 (Contributor, Author) commented Apr 1, 2025

@alanakbik Here is a script to test:

from flair.data import Sentence, Corpus
from flair.embeddings import TransformerEmbeddings
from flair.models import TextRegressor
from flair.trainers import ModelTrainer
from flair.trainers.plugins.base import TrainerPlugin
from flair.nn.multitask import make_multitask_model_and_corpus


class MetricPlugin(TrainerPlugin):
    """Test plugin. In practice this could do something like logging metrics to WandB."""

    @TrainerPlugin.hook
    def metric_recorded(self, record):
        print(f"Metric: {record}")

sentences = [Sentence("This is a sentence") for _ in range(100)]
for sentence in sentences:
    sentence.add_label("regression_label", 1.0)
corpus = Corpus(sentences)
model = TextRegressor(TransformerEmbeddings('bert-base-uncased'), label_name="regression_label")

multitask_model, multicorpus = make_multitask_model_and_corpus([(model, corpus)])

trainer = ModelTrainer(multitask_model, multicorpus)
trainer.train('regression_model/', plugins=[MetricPlugin()])

Running this with the changes in this PR, you can see lines printed like:

Metric: MetricRecord(dev/Task_0/mse at step 6, 1743547102.8555)
Metric: MetricRecord(dev/Task_0/mae at step 6, 1743547102.8556)
Metric: MetricRecord(dev/Task_0/pearson at step 6, 1743547102.8556)
Metric: MetricRecord(dev/Task_0/spearman at step 6, 1743547102.8556)
