Skip to content

Commit d6fd0ae

Browse files
authored
Merge pull request #117 from ardunn/master
Adding top level class + bugfixes
2 parents a64d1ff + 6e710a6 commit d6fd0ae

File tree

10 files changed

+434
-61
lines changed

10 files changed

+434
-61
lines changed

matbench/automl/adaptors.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from matbench.automl.tpot_configs.classifier import classifier_config_dict_mb
1515
from matbench.automl.tpot_configs.regressor import regressor_config_dict_mb
16-
from matbench.utils.utils import is_greater_better, MatbenchError
16+
from matbench.utils.utils import is_greater_better, MatbenchError, set_fitted, check_fitted
1717
from matbench.base import AutoMLAdaptor, LoggableMixin
1818

1919
__authors__ = ['Alex Dunn <[email protected]>']
@@ -92,7 +92,7 @@ def __init__(self, mode, logger=True, **tpot_kwargs):
9292
self.is_fit = False
9393
self.random_state = tpot_kwargs.get('random_state', None)
9494

95-
95+
@set_fitted
9696
def fit(self, df, target, **fit_kwargs):
9797
"""
9898
Train a TPOTRegressor or TPOTClassifier by fitting on a dataframe.
@@ -112,14 +112,14 @@ def fit(self, df, target, **fit_kwargs):
112112
X = df.drop(columns=target).values.tolist()
113113
self._features = df.drop(columns=target).columns.tolist()
114114
self._ml_data = {"X": X, "y": y}
115-
self.is_fit = True
116115
self.fitted_target = target
117116
self.logger.info("TPOT fitting started.")
118117
self._backend = self._backend.fit(X, y, **fit_kwargs)
119118
self.logger.info("TPOT fitting finished.")
120119
return self
121120

122121

122+
@check_fitted
123123
@property
124124
def _best_models(self):
125125
"""
@@ -134,9 +134,6 @@ def _best_models(self):
134134
best hyperparameter combination found.
135135
136136
"""
137-
if not self.is_fit:
138-
raise NotFittedError("Error, the model has not yet been fit")
139-
140137
self.greater_score_is_better = is_greater_better(
141138
self.backend.scoring_function)
142139

@@ -179,6 +176,7 @@ def _best_models(self):
179176
self.models = models
180177
return best_models_and_scores
181178

179+
@check_fitted
182180
def predict(self, df, target):
183181
"""
184182
Predict the target property of materials given a df of features.
@@ -231,8 +229,8 @@ def predict(self, df, target):
231229

232230
# Load a dataset
233231
df = load_dataset("elastic_tensor_2015").rename(columns={"formula": "composition"})[["composition", "K_VRH"]]
234-
testdf = df.iloc[60:90]
235-
traindf = df.iloc[:500]
232+
testdf = df.iloc[501:550]
233+
traindf = df.iloc[:100]
236234
target = "K_VRH"
237235

238236
# Get top-level transformers
@@ -245,13 +243,11 @@ def predict(self, df, target):
245243
traindf = autofeater.fit_transform(traindf, target)
246244
traindf = cleaner.fit_transform(traindf, target)
247245
traindf = reducer.fit_transform(traindf, target)
246+
learner.fit(traindf, target)
248247

249248
# Use transformers on testing data
250249
testdf = autofeater.transform(testdf, target)
251250
testdf = cleaner.transform(testdf, target)
252251
testdf = reducer.transform(testdf, target)
253-
254-
# Use training data to predict testing data
255-
learner.fit(traindf, target)
256252
testdf = learner.predict(testdf, target)
257253
print(testdf)

matbench/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import logging
6+
67
from matbench.utils.utils import initialize_logger, initialize_null_logger
78

89
__authors__ = ["Alex Dunn <[email protected]>", "Alex Ganose <[email protected]>"]

matbench/featurization/core.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pymatgen import Composition
33
from matminer.featurizers.conversions import StructureToOxidStructure, StrToComposition, DictToObject, StructureToComposition
44

5-
from matbench.utils.utils import MatbenchError
5+
from matbench.utils.utils import MatbenchError, check_fitted, set_fitted
66
from matbench.base import DataframeTransformer, LoggableMixin
77
from matbench.featurization.sets import CompositionFeaturizers, \
88
StructureFeaturizers, BSFeaturizers, DOSFeaturizers
@@ -197,6 +197,7 @@ def __init__(self, featurizers=None, ignore_cols=None, ignore_errors=True,
197197
featurizers[ftype] = []
198198
self.featurizers = featurizers
199199

200+
@set_fitted
200201
def fit(self, df, target):
201202
"""
202203
Fit all featurizers to the df.
@@ -219,7 +220,6 @@ def fit(self, df, target):
219220
Returns:
220221
(AutoFeaturizer): self
221222
"""
222-
self.is_fit = False
223223
df = self._prescreen_df(df, inplace=True)
224224
df = self._add_composition_from_structure(df)
225225
for featurizer_type, featurizers in self.featurizers.items():
@@ -234,9 +234,9 @@ def fit(self, df, target):
234234
self.features += f.feature_labels()
235235
self.logger.info("Fit {} to {} samples in dataframe."
236236
"".format(f.__class__.__name__, df.shape[0]))
237-
self.is_fit = True
238237
return self
239238

239+
@check_fitted
240240
def transform(self, df, target):
241241
"""
242242
Decorate a dataframe containing composition, structure, bandstructure,
@@ -249,9 +249,6 @@ def transform(self, df, target):
249249
Returns:
250250
df (pandas.DataFrame): Transformed dataframe containing features.
251251
"""
252-
if not self.is_fit:
253-
# Featurization requires featurizers already be fit...
254-
raise NotFittedError("AutoFeaturizer has not been fit!")
255252
df = self._prescreen_df(df, inplace=True)
256253
df = self._add_composition_from_structure(df)
257254

@@ -369,6 +366,7 @@ def _add_composition_from_structure(self, df):
369366
df = struct2comp.featurize_dataframe(df, "structure")
370367
return df
371368

369+
372370
if __name__ == "__main__":
373371
from matminer.datasets.dataset_retrieval import load_dataset
374372
df = load_dataset("flla")

matbench/featurization/tests/test_core.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,24 @@
66
from pymatgen import Composition
77
from matminer.data_retrieval.retrieve_MP import MPDataRetrieval
88
from matminer.datasets.dataset_retrieval import load_dataset
9-
from matminer.featurizers.composition import ElectronAffinity, ElementProperty, AtomicOrbitals
10-
from matminer.featurizers.structure import GlobalSymmetryFeatures, DensityFeatures
9+
from matminer.featurizers.composition import ElectronAffinity, ElementProperty, \
10+
AtomicOrbitals
11+
from matminer.featurizers.structure import GlobalSymmetryFeatures, \
12+
DensityFeatures
1113

1214
from matbench.featurization.core import AutoFeaturizer
1315

1416
test_dir = os.path.dirname(__file__)
1517

16-
__author__ = ["Alex Dunn <[email protected]>", "Alireza Faghaninia <[email protected]>"]
18+
__author__ = ["Alex Dunn <[email protected]>",
19+
"Alireza Faghaninia <[email protected]>"]
20+
1721

1822
class TestAutoFeaturizer(unittest.TestCase):
1923

2024
def setUp(self, limit=5):
21-
self.test_df = load_dataset('elastic_tensor_2015').rename(columns={"formula": "composition"})
25+
self.test_df = load_dataset('elastic_tensor_2015').rename(
26+
columns={"formula": "composition"})
2227
self.limit = limit
2328

2429
def test_sanity(self):
@@ -53,7 +58,6 @@ def test_featurize_composition(self):
5358
self.assertEqual(df["LUMO_element"].iloc[0], "Nb")
5459
self.assertTrue("composition" not in df.columns)
5560

56-
5761
def test_featurize_structure(self):
5862
"""
5963
Test automatic featurization while only considering structure.
@@ -120,7 +124,6 @@ def test_exclusions(self):
120124
for flabels in [ep_feats, ef_feats, ao_feats]:
121125
self.assertFalse(any([f in df.columns for f in flabels]))
122126

123-
124127
def test_featurize_bsdos(self, refresh_df_init=False, limit=1):
125128
"""
126129
Tests featurize_dos and featurize_bandstructure.
@@ -138,11 +141,11 @@ def test_featurize_bsdos(self, refresh_df_init=False, limit=1):
138141
if refresh_df_init:
139142
mpdr = MPDataRetrieval()
140143
df = mpdr.get_dataframe(criteria={"material_id": "mp-149"},
141-
properties=["pretty_formula",
142-
"dos",
143-
"bandstructure",
144-
"bandstructure_uniform"]
145-
)
144+
properties=["pretty_formula",
145+
"dos",
146+
"bandstructure",
147+
"bandstructure_uniform"]
148+
)
146149
df.to_pickle(os.path.join(test_dir, df_bsdos_pickled))
147150
else:
148151
df = pd.read_pickle(os.path.join(test_dir, df_bsdos_pickled))

0 commit comments

Comments
 (0)