Skip to content

Commit 7aaf46f

Browse files
authored
Merge pull request #16 from cthoyt/simplify-getting-started
update constructor of base ensemble to run w/o arguments (using same default as CLI), expose BaseEnsemble
2 parents ccbeb74 + edcf93e commit 7aaf46f

File tree

10 files changed

+86
-63
lines changed

10 files changed

+86
-63
lines changed

README.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,11 @@ python -m chebifier predict --help
7575
You can also use the package programmatically:
7676

7777
```python
78-
from chebifier.ensemble.base_ensemble import BaseEnsemble
79-
import yaml
78+
from chebifier import BaseEnsemble
8079
81-
# Load configuration from YAML file
82-
with open('configs/example_config.yml', 'r') as f:
83-
config = yaml.safe_load(f)
84-
85-
# Instantiate ensemble model
86-
ensemble = BaseEnsemble(config)
80+
# Instantiate ensemble model. If desired, can pass
81+
# a path to a configuration, like 'configs/example_config.yml'
82+
ensemble = BaseEnsemble()
8783
8884
# Make predictions
8985
smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"]

chebifier/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Note: The top-level package __init__.py runs only once,
22
# even if multiple subpackages are imported later.
33

4-
from ._custom_cache import PerSmilesPerModelLRUCache
4+
from ._custom_cache import PerSmilesPerModelLRUCache, modelwise_smiles_lru_cache
5+
from .ensemble.base_ensemble import BaseEnsemble
56

6-
modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)
7+
__all__ = [
8+
"BaseEnsemble",
9+
"PerSmilesPerModelLRUCache",
10+
"modelwise_smiles_lru_cache",
11+
]

chebifier/_custom_cache.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
from functools import wraps
77
from typing import Any, Callable
88

9+
__all__ = [
10+
"PerSmilesPerModelLRUCache",
11+
"modelwise_smiles_lru_cache",
12+
]
13+
914

1015
class PerSmilesPerModelLRUCache:
1116
"""
@@ -206,3 +211,6 @@ def _load_cache(self) -> None:
206211
self._cache = loaded
207212
except Exception as e:
208213
print(f"[Cache Load Error] {e}")
214+
215+
216+
modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)

chebifier/cli.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
import importlib.resources
2-
31
import click
4-
import yaml
52

63
from chebifier.model_registry import ENSEMBLES
74

@@ -72,43 +69,10 @@ def predict(
7269
resolve_inconsistencies=True,
7370
):
7471
"""Predict ChEBI classes for SMILES strings using an ensemble model."""
75-
# Load configuration from YAML file
76-
if not ensemble_config:
77-
print("Using default ensemble configuration")
78-
with (
79-
importlib.resources.files("chebifier")
80-
.joinpath("ensemble.yml")
81-
.open("r") as f
82-
):
83-
config = yaml.safe_load(f)
84-
else:
85-
print(f"Loading ensemble configuration from {ensemble_config}")
86-
with open(ensemble_config, "r") as f:
87-
config = yaml.safe_load(f)
88-
89-
with (
90-
importlib.resources.files("chebifier")
91-
.joinpath("model_registry.yml")
92-
.open("r") as f
93-
):
94-
model_registry = yaml.safe_load(f)
95-
96-
new_config = {}
97-
for model_name, entry in config.items():
98-
if "load_model" in entry:
99-
if entry["load_model"] not in model_registry:
100-
raise ValueError(
101-
f"Model {entry['load_model']} not found in model registry. "
102-
f"Available models are: {','.join(model_registry.keys())}."
103-
)
104-
new_config[model_name] = {**model_registry[entry["load_model"]], **entry}
105-
else:
106-
new_config[model_name] = entry
107-
config = new_config
10872

10973
# Instantiate ensemble model
11074
ensemble = ENSEMBLES[ensemble_type](
111-
config,
75+
ensemble_config,
11276
chebi_version=chebi_version,
11377
resolve_inconsistencies=resolve_inconsistencies,
11478
)

chebifier/ensemble/base_ensemble.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,60 @@
1+
import importlib
12
import os
23
import time
4+
from pathlib import Path
5+
from typing import Union
36

47
import torch
58
import tqdm
9+
import yaml
610

711
from chebifier.check_env import check_package_installed
812
from chebifier.hugging_face import download_model_files
913
from chebifier.inconsistency_resolution import PredictionSmoother
1014
from chebifier.prediction_models.base_predictor import BasePredictor
11-
from chebifier.utils import get_disjoint_files, load_chebi_graph
15+
from chebifier.utils import (
16+
get_default_configs,
17+
get_disjoint_files,
18+
load_chebi_graph,
19+
process_config,
20+
)
1221

1322

1423
class BaseEnsemble:
1524
def __init__(
1625
self,
17-
model_configs: dict,
26+
model_configs: Union[str, Path, dict, None] = None,
1827
chebi_version: int = 241,
1928
resolve_inconsistencies: bool = True,
2029
):
2130
# Deferred Import: To avoid circular import error
2231
from chebifier.model_registry import MODEL_TYPES
2332

33+
# Load configuration from YAML file
34+
if not model_configs:
35+
config = get_default_configs()
36+
elif isinstance(model_configs, dict):
37+
config = model_configs
38+
else:
39+
print(f"Loading ensemble configuration from {model_configs}")
40+
with open(model_configs, "r") as f:
41+
config = yaml.safe_load(f)
42+
43+
with (
44+
importlib.resources.files("chebifier")
45+
.joinpath("model_registry.yml")
46+
.open("r") as f
47+
):
48+
model_registry = yaml.safe_load(f)
49+
50+
processed_configs = process_config(config, model_registry)
51+
2452
self.chebi_graph = load_chebi_graph()
2553
self.disjoint_files = get_disjoint_files()
2654

2755
self.models = []
2856
self.positive_prediction_threshold = 0.5
29-
for model_name, model_config in model_configs.items():
57+
for model_name, model_config in processed_configs.items():
3058
model_cls = MODEL_TYPES[model_config["type"]]
3159
if "hugging_face" in model_config:
3260
hugging_face_kwargs = download_model_files(model_config["hugging_face"])

chebifier/inconsistency_resolution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import csv
22
import os
3-
import torch
43
from pathlib import Path
54

5+
import torch
6+
67

78
def get_disjoint_groups(disjoint_files):
89
if disjoint_files is None:

chebifier/model_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
WMVwithPPVNPVEnsemble,
55
)
66
from chebifier.prediction_models import (
7+
ChEBILookupPredictor,
78
ChemlogPeptidesPredictor,
89
ElectraPredictor,
910
ResGatedPredictor,
10-
ChEBILookupPredictor,
1111
)
1212
from chebifier.prediction_models.c3p_predictor import C3PPredictor
1313
from chebifier.prediction_models.chemlog_predictor import (
14-
ChemlogXMolecularEntityPredictor,
1514
ChemlogOrganoXCompoundPredictor,
15+
ChemlogXMolecularEntityPredictor,
1616
)
1717

1818
ENSEMBLES = {

chebifier/prediction_models/base_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from abc import ABC
33

4-
from chebifier import modelwise_smiles_lru_cache
4+
from .._custom_cache import modelwise_smiles_lru_cache
55

66

77
class BasePredictor(ABC):

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import tqdm
44

5-
from .base_predictor import BasePredictor
65
from .. import modelwise_smiles_lru_cache
6+
from .base_predictor import BasePredictor
77

88
AA_DICT = {
99
"A": "L-alanine",

chebifier/utils.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import importlib.resources
12
import os
3+
import pickle
24

5+
import fastobo
36
import networkx as nx
47
import requests
5-
import fastobo
8+
import yaml
9+
610
from chebifier.hugging_face import download_model_files
7-
import pickle
811

912

1013
def load_chebi_graph(filename=None):
@@ -123,9 +126,27 @@ def get_disjoint_files():
123126
return disjoint_files
124127

125128

126-
if __name__ == "__main__":
127-
# chebi_graph = build_chebi_graph(chebi_version=241)
128-
# save the graph to a file
129-
# pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
130-
chebi_graph = load_chebi_graph()
131-
print(chebi_graph)
129+
def get_default_configs():
130+
default_config_name = "ensemble.yml"
131+
print(f"Using default ensemble configuration from {default_config_name}")
132+
with (
133+
importlib.resources.files("chebifier")
134+
.joinpath(default_config_name)
135+
.open("r") as f
136+
):
137+
return yaml.safe_load(f)
138+
139+
140+
def process_config(config, model_registry):
141+
new_config = {}
142+
for model_name, entry in config.items():
143+
if "load_model" in entry:
144+
if entry["load_model"] not in model_registry:
145+
raise ValueError(
146+
f"Model {entry['load_model']} not found in model registry. "
147+
f"Available models are: {','.join(model_registry.keys())}."
148+
)
149+
new_config[model_name] = {**model_registry[entry["load_model"]], **entry}
150+
else:
151+
new_config[model_name] = entry
152+
return new_config

0 commit comments

Comments
 (0)