-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
18e8585
commit dce9b6e
Showing
1 changed file
with
38 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Here, we evaluate an expert module (optionally restored from a checkpoint) with the GSM evaluator on the test split. | ||
|
||
import os | ||
from mttl.arguments import ExpertConfig | ||
from mttl.datamodule.base import get_datamodule | ||
from mttl.models.library.expert_library import ExpertLibrary | ||
from mttl.models.expert_model import ( | ||
ExpertModel, | ||
MultiExpertModel, | ||
MultiExpertModelConfig, | ||
ExpertModelConfig, | ||
) | ||
from mttl.models.train_utils import train_model | ||
|
||
from mttl.evaluators.gsm_evaluator import GsmEvaluator | ||
from mttl.evaluators.rouge_evaluator import RougeEvaluator | ||
from mttl.models.containers.selectors.base import UniformSelectorConfig | ||
from mttl.arguments import EvaluationConfig, ExpertConfig | ||
from mttl.models.lightning.expert_module import ExpertModule | ||
import torch | ||
from mttl.logging import setup_logging | ||
|
||
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
setup_logging()

# Parse the evaluation configuration.
args = EvaluationConfig.parse()

# Build the datamodule in generation mode and hand it to the GSM evaluator.
datamodule = get_datamodule(args, for_generation=True)
evaluator = GsmEvaluator(datamodule)

# Instantiate the module from the parsed arguments and move it to `device`.
module = ExpertModule(**vars(args)).to(device)

if args.checkpoint is not None:
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only checkpoints from a trusted source should be loaded here.
    # map_location keeps a checkpoint that was saved on a GPU loadable when
    # this script runs on a CPU-only host.
    checkpoint = torch.load(
        args.checkpoint, map_location=device, weights_only=False
    )["state_dict"]
    module.load_state_dict(checkpoint)

# Evaluate the underlying model on the test split.
# NOTE(review): `result` is not printed or stored anywhere in this script.
result = evaluator.evaluate(module.model, split="test")