Skip to content

Commit

Permalink
add gsm evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
shuishen112 committed Dec 23, 2024
1 parent 18e8585 commit dce9b6e
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions gsm_evaluator_with_lora_soup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Evaluate an ExpertModule (optionally restored from a checkpoint) with GsmEvaluator on the test split.

import os
from mttl.arguments import ExpertConfig
from mttl.datamodule.base import get_datamodule
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.expert_model import (
ExpertModel,
MultiExpertModel,
MultiExpertModelConfig,
ExpertModelConfig,
)
from mttl.models.train_utils import train_model

from mttl.evaluators.gsm_evaluator import GsmEvaluator
from mttl.evaluators.rouge_evaluator import RougeEvaluator
from mttl.models.containers.selectors.base import UniformSelectorConfig
from mttl.arguments import EvaluationConfig, ExpertConfig
from mttl.models.lightning.expert_module import ExpertModule
import torch
from mttl.logging import setup_logging

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
setup_logging()

# Parse the evaluation configuration.
args = EvaluationConfig.parse()

# Build the datamodule in generation mode and hand it to the GSM evaluator.
datamodule = get_datamodule(args, for_generation=True)
evaluator = GsmEvaluator(datamodule)

# Instantiate the module from the parsed arguments and move it to `device`.
module = ExpertModule(**vars(args)).to(device)

if args.checkpoint is not None:
    # map_location="cpu" lets a checkpoint that was saved on a CUDA machine be
    # deserialized on a CPU-only host; load_state_dict then copies the values
    # into the module's existing parameters, which already live on `device`.
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only
    # load checkpoints that come from a trusted source.
    checkpoint = torch.load(
        args.checkpoint, map_location="cpu", weights_only=False
    )["state_dict"]
    module.load_state_dict(checkpoint)

# Evaluate the wrapped model on the test split; the outcome is kept in `result`.
result = evaluator.evaluate(module.model, split="test")

0 comments on commit dce9b6e

Please sign in to comment.