Skip to content

Commit a6e89ac

Browse files
authored
Merge pull request #202 from alexbrillant/data-shuffling-epochs-repeater
Add Data Shuffling, And Epochs Repeater Steps
2 parents 52c07bc + b7abc39 commit a6e89ac

File tree

15 files changed

+477
-108
lines changed

15 files changed

+477
-108
lines changed

neuraxle/base.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def apply(self, method_name: str, *kargs, **kwargs) -> 'BaseStep':
881881

882882
return self
883883

884-
def handle_fit(self, data_container: DataContainer, context: ExecutionContext) -> ('BaseStep', DataContainer):
884+
def handle_fit(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':
885885
"""
886886
Override this to add side effects or change the execution flow before (or after) calling :func:`~neuraxle.base.BaseStep.fit`.
887887
The default behavior is to rehash current ids with the step hyperparameters.
@@ -897,12 +897,9 @@ def handle_fit(self, data_container: DataContainer, context: ExecutionContext) -
897897
data_container, context = self._will_process(data_container, context)
898898
data_container, context = self._will_fit(data_container, context)
899899

900-
new_self, data_container = self._fit_data_container(data_container, context)
900+
new_self = self._fit_data_container(data_container, context)
901901

902-
data_container = self._did_fit(data_container, context)
903-
data_container = self._did_process(data_container, context)
904-
905-
return new_self, data_container
902+
return new_self
906903

907904
def handle_fit_transform(self, data_container: DataContainer, context: ExecutionContext) -> ('BaseStep', DataContainer):
908905
"""
@@ -965,7 +962,7 @@ def _did_fit(self, data_container: DataContainer, context: ExecutionContext) ->
965962
"""
966963
return data_container
967964

968-
def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> ('BaseStep', DataContainer):
965+
def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':
969966
"""
970967
Fit data container.
971968
@@ -974,8 +971,7 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
974971
:return: fitted self
975972
:rtype: BaseStep
976973
"""
977-
new_self = self.fit(data_container.data_inputs, data_container.expected_outputs)
978-
return new_self, data_container
974+
return self.fit(data_container.data_inputs, data_container.expected_outputs)
979975

980976
def _will_fit_transform(self, data_container: DataContainer, context: ExecutionContext) -> (DataContainer, ExecutionContext):
981977
"""
@@ -1634,7 +1630,7 @@ def get_hyperparams(self) -> HyperparameterSamples:
16341630
"""
16351631
return HyperparameterSamples({
16361632
**self.hyperparams.to_flat_as_dict_primitive(),
1637-
self.wrapped.name: self.wrapped.hyperparams.to_flat_as_dict_primitive()
1633+
self.wrapped.name: self.wrapped.get_hyperparams().to_flat_as_dict_primitive()
16381634
}).to_flat()
16391635

16401636
def set_hyperparams_space(self, hyperparams_space: HyperparameterSpace) -> 'BaseStep':
@@ -1670,7 +1666,7 @@ def get_hyperparams_space(self) -> HyperparameterSpace:
16701666
"""
16711667
return HyperparameterSpace({
16721668
**self.hyperparams_space.to_flat_as_dict_primitive(),
1673-
self.wrapped.name: self.wrapped.hyperparams_space.to_flat_as_dict_primitive()
1669+
self.wrapped.name: self.wrapped.get_hyperparams_space().to_flat_as_dict_primitive()
16741670
}).to_flat()
16751671

16761672
def set_step(self, step: BaseStep) -> BaseStep:
@@ -1703,8 +1699,8 @@ def _fit_transform_data_container(self, data_container, context):
17031699
return self, data_container
17041700

17051701
def _fit_data_container(self, data_container, context):
1706-
self.wrapped, data_container = self.wrapped.handle_fit(data_container, context)
1707-
return self, data_container
1702+
self.wrapped = self.wrapped.handle_fit(data_container, context)
1703+
return self
17081704

17091705
def _transform_data_container(self, data_container, context):
17101706
data_container = self.wrapped.handle_transform(data_container, context)
@@ -1756,6 +1752,39 @@ def apply_method(self, method: Callable, *kargs, **kwargs) -> 'BaseStep':
17561752
self.wrapped = self.wrapped.apply_method(method, *kargs, **kwargs)
17571753
return self
17581754

1755+
1756+
def mutate(self, new_method="inverse_transform", method_to_assign_to="transform", warn=True) -> 'BaseStep':
1757+
"""
1758+
Mutate self, and self.wrapped. Please refer to :func:`~neuraxle.base.BaseStep.mutate` for more information.
1759+
1760+
:param new_method: the method to replace transform with, if there is no pending ``will_mutate_to`` call.
1761+
:param method_to_assign_to: the method to which the new method will be assigned to, if there is no pending ``will_mutate_to`` call.
1762+
:param warn: (verbose) whether or not to warn about the nonexistence of the method.
1763+
:return: self, a copy of self, or even perhaps a new or different BaseStep object.
1764+
"""
1765+
new_self = BaseStep.mutate(self, new_method, method_to_assign_to, warn)
1766+
self.wrapped = self.wrapped.mutate(new_method, method_to_assign_to, warn)
1767+
1768+
return new_self
1769+
1770+
def will_mutate_to(
1771+
self, new_base_step: 'BaseStep' = None, new_method: str = None, method_to_assign_to: str = None
1772+
) -> 'BaseStep':
1773+
"""
1774+
Add a pending mutation to self and self.wrapped. Please refer to :func:`~neuraxle.base.BaseStep.will_mutate_to` for more information.
1775+
1776+
:param new_base_step: if it is not None, upon calling ``mutate``, the object it will mutate to will be this provided new_base_step.
1777+
:type new_base_step: BaseStep
1778+
:param method_to_assign_to: if it is not None, upon calling ``mutate``, the method_to_assign_to will be the one that is used on the provided new_base_step.
1779+
:type method_to_assign_to: str
1780+
:param new_method: if it is not None, upon calling ``mutate``, the new_method will be the one that is used on the provided new_base_step.
1781+
:type new_method: str
1782+
:return: self
1783+
:rtype: BaseStep
1784+
"""
1785+
new_self = BaseStep.will_mutate_to(self, new_base_step, new_method, method_to_assign_to)
1786+
return new_self
1787+
17591788
def __repr__(self):
17601789
output = self.__class__.__name__ + "(\n\twrapped=" + repr(
17611790
self.wrapped) + "," + "\n\thyperparameters=" + pprint.pformat(

neuraxle/checkpoints.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def __init__(
207207
BaseStep.__init__(self)
208208
self.all_checkpointers = all_checkpointers
209209

210-
def _fit_data_container(self, data_container, context) -> Tuple['Checkpoint', DataContainer]:
210+
def _fit_data_container(self, data_container, context) -> 'Checkpoint':
211211
"""
212212
Saves step, and data checkpointers for the FIT execution mode.
213213
@@ -217,7 +217,7 @@ def _fit_data_container(self, data_container, context) -> Tuple['Checkpoint', Da
217217
:rtype: neuraxle.data_container.DataContainer
218218
"""
219219
self.save_checkpoint(data_container, context)
220-
return self, data_container
220+
return self
221221

222222
def _transform_data_container(self, data_container, context):
223223
"""
@@ -715,7 +715,6 @@ def should_resume(self, data_container: DataContainer, context: ExecutionContext
715715
if not self.summary_checkpointer.checkpoint_exists(context.get_path(), data_container):
716716
return False
717717

718-
719718
current_ids = self.summary_checkpointer.read_summary(
720719
checkpoint_path=context.get_path(),
721720
data_container=data_container

neuraxle/metaopt/random.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"""
2626

2727
import copy
28+
import json
2829
import math
2930
from abc import ABC, abstractmethod
3031
from typing import List, Callable, Tuple, Iterable
@@ -141,7 +142,7 @@ def _fit_transform_data_container(self, data_container: DataContainer, context:
141142
"""
142143
train_data_container, validation_data_container = self.split_data_container(data_container)
143144

144-
self.wrapped, _ = self.wrapped.handle_fit(train_data_container, context.push(self.wrapped))
145+
self.wrapped = self.wrapped.handle_fit(train_data_container, context.push(self.wrapped))
145146

146147
results_data_container = self.wrapped.handle_transform(train_data_container, context.push(self.wrapped))
147148

@@ -581,17 +582,54 @@ def __init__(
581582
wrapped=None,
582583
n_iter: int = 10,
583584
higher_score_is_better: bool = True,
584-
validation_technique: BaseCrossValidationWrapper = KFoldCrossValidationWrapper(),
585+
validation_technique: BaseValidation = KFoldCrossValidationWrapper(),
585586
refit=True,
586587
):
587588
if wrapped is not None:
588589
MetaStepMixin.__init__(self, wrapped)
589590
BaseStep.__init__(self)
590591
self.n_iter = n_iter
591592
self.higher_score_is_better = higher_score_is_better
592-
self.validation_technique: BaseCrossValidationWrapper = validation_technique
593+
self.validation_technique: BaseValidation = validation_technique
593594
self.refit = refit
594595

596+
def _fit_transform_data_container(self, data_container, context):
597+
fitted_self = self._fit_data_container(data_container, context)
598+
best_model_predictions_data_container = self._transform_data_container(data_container, context)
599+
return fitted_self, best_model_predictions_data_container
600+
601+
def _fit_data_container(self, data_container, context):
602+
started = False
603+
best_hyperparams = None
604+
605+
for _ in range(self.n_iter):
606+
607+
step = copy.copy(self.wrapped)
608+
609+
new_hyperparams = step.get_hyperparams_space().rvs()
610+
step.update_hyperparams(new_hyperparams)
611+
612+
step: BaseValidation = copy.copy(self.validation_technique).set_step(step)
613+
614+
step = step.handle_fit(data_container, context)
615+
score = step.scores_mean
616+
617+
if not started or self.higher_score_is_better == (score > self.score):
618+
started = True
619+
self.score = score
620+
self.best_validation_wrapper_of_model = copy.copy(step)
621+
print('score: {}'.format(score))
622+
best_hyperparams = new_hyperparams
623+
print('best_hyperparams: \n{}\n'.format(best_hyperparams))
624+
625+
self.best_validation_wrapper_of_model.wrapped.update_hyperparams(best_hyperparams)
626+
627+
self.best_model = copy.copy(self.wrapped).update_hyperparams(best_hyperparams)
628+
if self.refit:
629+
self.best_model = self.best_model.handle_fit(data_container, context)
630+
631+
return self
632+
595633
def fit_transform(self, data_inputs, expected_outputs):
596634
return self.fit(data_inputs, expected_outputs), self.transform(data_inputs)
597635

@@ -606,7 +644,7 @@ def fit(self, data_inputs, expected_outputs=None) -> 'BaseStep':
606644
new_hyperparams = step.get_hyperparams_space().rvs()
607645
step.set_hyperparams(new_hyperparams)
608646

609-
step: BaseCrossValidationWrapper = copy.copy(self.validation_technique).set_step(step)
647+
step: BaseValidation = copy.copy(self.validation_technique).set_step(step)
610648

611649
step = step.fit(data_inputs, expected_outputs)
612650
score = step.scores_mean
@@ -615,15 +653,17 @@ def fit(self, data_inputs, expected_outputs=None) -> 'BaseStep':
615653
started = True
616654
self.score = score
617655
self.best_validation_wrapper_of_model = copy.copy(step)
656+
657+
print('\nbest_score: {}'.format(score))
618658
best_hyperparams = new_hyperparams
659+
print('best_hyperparams: ')
660+
print(json.dumps(best_hyperparams.to_nested_dict(), sort_keys=True, indent=4))
619661

620662
self.best_validation_wrapper_of_model.wrapped.set_hyperparams(best_hyperparams)
621663

664+
self.best_model = copy.copy(self.wrapped).set_hyperparams(best_hyperparams)
622665
if self.refit:
623-
self.best_model = self.best_validation_wrapper_of_model.wrapped.fit(
624-
data_inputs,
625-
expected_outputs
626-
)
666+
self.best_model = self.best_model.fit(data_inputs, expected_outputs)
627667

628668
return self
629669

@@ -634,3 +674,9 @@ def transform(self, data_inputs):
634674
if self.best_validation_wrapper_of_model is None:
635675
raise Exception('Cannot transform RandomSearch before fit')
636676
return self.best_validation_wrapper_of_model.wrapped.transform(data_inputs)
677+
678+
def _transform_data_container(self, data_container, context):
679+
if self.best_validation_wrapper_of_model is None:
680+
raise Exception('Cannot transform RandomSearch before fit')
681+
682+
return self.best_validation_wrapper_of_model.wrapped.handle_transform(data_container, context)

neuraxle/pipeline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def fit(self, data_inputs, expected_outputs=None) -> 'Pipeline':
116116
data_container = self.hash_data_container(data_container)
117117
context = ExecutionContext(self.cache_folder, ExecutionMode.FIT)
118118
context = context.push(self)
119-
new_self, data_container = self._fit_data_container(data_container, context)
119+
new_self = self._fit_data_container(data_container, context)
120120

121121
return new_self
122122

@@ -131,7 +131,7 @@ def inverse_transform(self, processed_outputs) -> Any:
131131
processed_outputs = step.inverse_transform(processed_outputs)
132132
return processed_outputs
133133

134-
def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> ('Pipeline', DataContainer):
134+
def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'Pipeline':
135135
"""
136136
After loading the last checkpoint, fit transform each pipeline steps,
137137
but only fit the last pipeline step.
@@ -153,14 +153,14 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
153153
if index != index_last_step:
154154
step, data_container = step.handle_fit_transform(data_container, context)
155155
else:
156-
step, data_container = step.handle_fit(data_container, context)
156+
step = step.handle_fit(data_container, context)
157157

158158
new_steps_as_tuple.append((step_name, step))
159159

160160
self.steps_as_tuple = self.steps_as_tuple[
161161
:len(self.steps_as_tuple) - len(steps_left_to_do)] + new_steps_as_tuple
162162

163-
return self, data_container
163+
return self
164164

165165
def _fit_transform_data_container(self, data_container: DataContainer, context: ExecutionContext) -> (
166166
'Pipeline', DataContainer):
@@ -334,7 +334,7 @@ def fit(self, data_inputs, expected_outputs=None) -> 'Pipeline':
334334
data_container.set_current_ids(current_ids)
335335

336336
context = ExecutionContext(self.cache_folder, ExecutionMode.FIT_TRANSFORM)
337-
new_self, data_container = self.handle_fit(data_container, context)
337+
new_self = self.handle_fit(data_container, context)
338338

339339
return new_self
340340

neuraxle/steps/caching.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,18 @@
2222
project, visit https://www.umaneo.com/ for more information on Umaneo Technologies Inc.
2323
2424
"""
25+
import hashlib
2526
import os
2627
import pickle
2728
import shutil
28-
from abc import abstractmethod
29+
from abc import abstractmethod, ABC
2930
from typing import Iterable, Any
3031

3132
from neuraxle.base import MetaStepMixin, BaseStep, NonFittableMixin, NonTransformableMixin, \
3233
ExecutionContext
3334
from neuraxle.data_container import DataContainer
3435
from neuraxle.pipeline import DEFAULT_CACHE_FOLDER
35-
from neuraxle.steps.misc import BaseValueHasher, Md5Hasher, VALUE_CACHING
36+
from neuraxle.steps.misc import VALUE_CACHING
3637

3738

3839
class ValueCachingWrapper(MetaStepMixin, NonFittableMixin, NonTransformableMixin, BaseStep):
@@ -44,7 +45,7 @@ def __init__(
4445
self,
4546
wrapped: BaseStep,
4647
cache_folder: str = DEFAULT_CACHE_FOLDER,
47-
value_hasher: BaseValueHasher = None,
48+
value_hasher: 'BaseValueHasher' = None,
4849
):
4950
BaseStep.__init__(self)
5051
MetaStepMixin.__init__(self, wrapped)
@@ -214,3 +215,17 @@ def contains_cache_for(self, data_input) -> bool:
214215
def get_cache_path_for(self, data_input):
215216
hash_value = self._hash_value(data_input)
216217
return os.path.join(self.checkpoint_path, '{0}.pickle'.format(hash_value))
218+
219+
220+
class BaseValueHasher(ABC):
221+
@abstractmethod
222+
def hash(self, data_input):
223+
raise NotImplementedError()
224+
225+
226+
class Md5Hasher(BaseValueHasher):
227+
def hash(self, data_input):
228+
m = hashlib.md5()
229+
m.update(str.encode(str(data_input)))
230+
231+
return m.hexdigest()

0 commit comments

Comments
 (0)