|
| 1 | +import importlib |
1 | 2 | import os
|
2 | 3 | import time
|
| 4 | +from pathlib import Path |
| 5 | +from typing import Union |
3 | 6 |
|
4 | 7 | import torch
|
5 | 8 | import tqdm
|
| 9 | +import yaml |
6 | 10 |
|
7 | 11 | from chebifier.check_env import check_package_installed
|
8 | 12 | from chebifier.hugging_face import download_model_files
|
9 | 13 | from chebifier.inconsistency_resolution import PredictionSmoother
|
10 | 14 | from chebifier.prediction_models.base_predictor import BasePredictor
|
11 |
| -from chebifier.utils import get_disjoint_files, load_chebi_graph |
| 15 | +from chebifier.utils import ( |
| 16 | + get_default_configs, |
| 17 | + get_disjoint_files, |
| 18 | + load_chebi_graph, |
| 19 | + process_config, |
| 20 | +) |
12 | 21 |
|
13 | 22 |
|
14 | 23 | class BaseEnsemble:
|
15 | 24 | def __init__(
|
16 | 25 | self,
|
17 |
| - model_configs: dict, |
| 26 | + model_configs: Union[str, Path, dict, None] = None, |
18 | 27 | chebi_version: int = 241,
|
19 | 28 | resolve_inconsistencies: bool = True,
|
20 | 29 | ):
|
21 | 30 | # Deferred Import: To avoid circular import error
|
22 | 31 | from chebifier.model_registry import MODEL_TYPES
|
23 | 32 |
|
| 33 | + # Load configuration from YAML file |
| 34 | + if not model_configs: |
| 35 | + config = get_default_configs() |
| 36 | + elif isinstance(model_configs, dict): |
| 37 | + config = model_configs |
| 38 | + else: |
| 39 | + print(f"Loading ensemble configuration from {model_configs}") |
| 40 | + with open(model_configs, "r") as f: |
| 41 | + config = yaml.safe_load(f) |
| 42 | + |
| 43 | + with ( |
| 44 | + importlib.resources.files("chebifier") |
| 45 | + .joinpath("model_registry.yml") |
| 46 | + .open("r") as f |
| 47 | + ): |
| 48 | + model_registry = yaml.safe_load(f) |
| 49 | + |
| 50 | + processed_configs = process_config(config, model_registry) |
| 51 | + |
24 | 52 | self.chebi_graph = load_chebi_graph()
|
25 | 53 | self.disjoint_files = get_disjoint_files()
|
26 | 54 |
|
27 | 55 | self.models = []
|
28 | 56 | self.positive_prediction_threshold = 0.5
|
29 |
| - for model_name, model_config in model_configs.items(): |
| 57 | + for model_name, model_config in processed_configs.items(): |
30 | 58 | model_cls = MODEL_TYPES[model_config["type"]]
|
31 | 59 | if "hugging_face" in model_config:
|
32 | 60 | hugging_face_kwargs = download_model_files(model_config["hugging_face"])
|
|
0 commit comments