@@ -881,7 +881,7 @@ def apply(self, method_name: str, *kargs, **kwargs) -> 'BaseStep':
881
881
882
882
return self
883
883
884
- def handle_fit (self , data_container : DataContainer , context : ExecutionContext ) -> ( 'BaseStep' , DataContainer ) :
884
+ def handle_fit (self , data_container : DataContainer , context : ExecutionContext ) -> 'BaseStep' :
885
885
"""
886
886
Override this to add side effects or change the execution flow before (or after) calling :func:`~neuraxle.base.BaseStep.fit`.
887
887
The default behavior is to rehash current ids with the step hyperparameters.
@@ -897,12 +897,9 @@ def handle_fit(self, data_container: DataContainer, context: ExecutionContext) -
897
897
data_container , context = self ._will_process (data_container , context )
898
898
data_container , context = self ._will_fit (data_container , context )
899
899
900
- new_self , data_container = self ._fit_data_container (data_container , context )
900
+ new_self = self ._fit_data_container (data_container , context )
901
901
902
- data_container = self ._did_fit (data_container , context )
903
- data_container = self ._did_process (data_container , context )
904
-
905
- return new_self , data_container
902
+ return new_self
906
903
907
904
def handle_fit_transform (self , data_container : DataContainer , context : ExecutionContext ) -> ('BaseStep' , DataContainer ):
908
905
"""
@@ -965,7 +962,7 @@ def _did_fit(self, data_container: DataContainer, context: ExecutionContext) ->
965
962
"""
966
963
return data_container
967
964
968
- def _fit_data_container (self , data_container : DataContainer , context : ExecutionContext ) -> ( 'BaseStep' , DataContainer ) :
965
+ def _fit_data_container (self , data_container : DataContainer , context : ExecutionContext ) -> 'BaseStep' :
969
966
"""
970
967
Fit data container.
971
968
@@ -974,8 +971,7 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
974
971
:return: (fitted self, data container)
975
972
:rtype: (BaseStep, DataContainer)
976
973
"""
977
- new_self = self .fit (data_container .data_inputs , data_container .expected_outputs )
978
- return new_self , data_container
974
+ return self .fit (data_container .data_inputs , data_container .expected_outputs )
979
975
980
976
def _will_fit_transform (self , data_container : DataContainer , context : ExecutionContext ) -> (DataContainer , ExecutionContext ):
981
977
"""
@@ -1634,7 +1630,7 @@ def get_hyperparams(self) -> HyperparameterSamples:
1634
1630
"""
1635
1631
return HyperparameterSamples ({
1636
1632
** self .hyperparams .to_flat_as_dict_primitive (),
1637
- self .wrapped .name : self .wrapped .hyperparams .to_flat_as_dict_primitive ()
1633
+ self .wrapped .name : self .wrapped .get_hyperparams () .to_flat_as_dict_primitive ()
1638
1634
}).to_flat ()
1639
1635
1640
1636
def set_hyperparams_space (self , hyperparams_space : HyperparameterSpace ) -> 'BaseStep' :
@@ -1670,7 +1666,7 @@ def get_hyperparams_space(self) -> HyperparameterSpace:
1670
1666
"""
1671
1667
return HyperparameterSpace ({
1672
1668
** self .hyperparams_space .to_flat_as_dict_primitive (),
1673
- self .wrapped .name : self .wrapped .hyperparams_space .to_flat_as_dict_primitive ()
1669
+ self .wrapped .name : self .wrapped .get_hyperparams_space () .to_flat_as_dict_primitive ()
1674
1670
}).to_flat ()
1675
1671
1676
1672
def set_step (self , step : BaseStep ) -> BaseStep :
@@ -1703,8 +1699,8 @@ def _fit_transform_data_container(self, data_container, context):
1703
1699
return self , data_container
1704
1700
1705
1701
def _fit_data_container (self , data_container , context ):
1706
- self .wrapped , data_container = self .wrapped .handle_fit (data_container , context )
1707
- return self , data_container
1702
+ self .wrapped = self .wrapped .handle_fit (data_container , context )
1703
+ return self
1708
1704
1709
1705
def _transform_data_container (self , data_container , context ):
1710
1706
data_container = self .wrapped .handle_transform (data_container , context )
@@ -1756,6 +1752,39 @@ def apply_method(self, method: Callable, *kargs, **kwargs) -> 'BaseStep':
1756
1752
self .wrapped = self .wrapped .apply_method (method , * kargs , ** kwargs )
1757
1753
return self
1758
1754
1755
+
1756
+ def mutate (self , new_method = "inverse_transform" , method_to_assign_to = "transform" , warn = True ) -> 'BaseStep' :
1757
+ """
1758
+ Mutate self, and self.wrapped. Please refer to :func:`~neuraxle.base.BaseStep.mutate` for more information.
1759
+
1760
+ :param new_method: the method to replace transform with, if there is no pending ``will_mutate_to`` call.
1761
+ :param method_to_assign_to: the method to which the new method will be assigned to, if there is no pending ``will_mutate_to`` call.
1762
+ :param warn: (verbose) wheter or not to warn about the inexistence of the method.
1763
+ :return: self, a copy of self, or even perhaps a new or different BaseStep object.
1764
+ """
1765
+ new_self = BaseStep .mutate (self , new_method , method_to_assign_to , warn )
1766
+ self .wrapped = self .wrapped .mutate (new_method , method_to_assign_to , warn )
1767
+
1768
+ return new_self
1769
+
1770
+ def will_mutate_to (
1771
+ self , new_base_step : 'BaseStep' = None , new_method : str = None , method_to_assign_to : str = None
1772
+ ) -> 'BaseStep' :
1773
+ """
1774
+ Add pending mutate self, self.wrapped. Please refer to :func:`~neuraxle.base.BaseStep.will_mutate_to` for more information.
1775
+
1776
+ :param new_base_step: if it is not None, upon calling ``mutate``, the object it will mutate to will be this provided new_base_step.
1777
+ :type new_base_step: BaseStep
1778
+ :param method_to_assign_to: if it is not None, upon calling ``mutate``, the method_to_affect will be the one that is used on the provided new_base_step.
1779
+ :type method_to_assign_to: str
1780
+ :param new_method: if it is not None, upon calling ``mutate``, the new_method will be the one that is used on the provided new_base_step.
1781
+ :type new_method: str
1782
+ :return: self
1783
+ :rtype: BaseStep
1784
+ """
1785
+ new_self = BaseStep .will_mutate_to (self , new_base_step , new_method , method_to_assign_to )
1786
+ return new_self
1787
+
1759
1788
def __repr__ (self ):
1760
1789
output = self .__class__ .__name__ + "(\n \t wrapped=" + repr (
1761
1790
self .wrapped ) + "," + "\n \t hyperparameters=" + pprint .pformat (
0 commit comments