|
10 | 10 | # License: MIT License
|
11 | 11 |
|
12 | 12 | import numpy as np
|
13 |
| -import warnings |
14 | 13 |
|
15 | 14 | from .bregman import sinkhorn
|
16 | 15 | from .lp import emd
|
17 | 16 | from .utils import unif, dist, kernel
|
| 17 | +from .utils import deprecated, BaseEstimator |
18 | 18 | from .optim import cg
|
19 | 19 | from .optim import gcg
|
20 |
| -from .deprecation import deprecated |
21 | 20 |
|
22 | 21 |
|
23 | 22 | def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
|
@@ -936,139 +935,6 @@ def predict(self, x):
|
936 | 935 | print("Warning, model not fitted yet, returning None")
|
937 | 936 | return None
|
938 | 937 |
|
939 |
| -############################################################################## |
940 |
| -# proposal |
941 |
| -############################################################################## |
942 |
| - |
943 |
| - |
944 |
| -# adapted from sklearn |
945 |
| - |
946 |
| -class BaseEstimator(object): |
947 |
| - """Base class for all estimators in scikit-learn |
948 |
| - Notes |
949 |
| - ----- |
950 |
| - All estimators should specify all the parameters that can be set |
951 |
| - at the class level in their ``__init__`` as explicit keyword |
952 |
| - arguments (no ``*args`` or ``**kwargs``). |
953 |
| - """ |
954 |
| - |
955 |
| - @classmethod |
956 |
| - def _get_param_names(cls): |
957 |
| - """Get parameter names for the estimator""" |
958 |
| - try: |
959 |
| - from inspect import signature |
960 |
| - except ImportError: |
961 |
| - from .externals.funcsigs import signature |
962 |
| - # fetch the constructor or the original constructor before |
963 |
| - # deprecation wrapping if any |
964 |
| - init = getattr(cls.__init__, 'deprecated_original', cls.__init__) |
965 |
| - if init is object.__init__: |
966 |
| - # No explicit constructor to introspect |
967 |
| - return [] |
968 |
| - |
969 |
| - # introspect the constructor arguments to find the model parameters |
970 |
| - # to represent |
971 |
| - init_signature = signature(init) |
972 |
| - # Consider the constructor parameters excluding 'self' |
973 |
| - parameters = [p for p in init_signature.parameters.values() |
974 |
| - if p.name != 'self' and p.kind != p.VAR_KEYWORD] |
975 |
| - for p in parameters: |
976 |
| - if p.kind == p.VAR_POSITIONAL: |
977 |
| - raise RuntimeError("scikit-learn estimators should always " |
978 |
| - "specify their parameters in the signature" |
979 |
| - " of their __init__ (no varargs)." |
980 |
| - " %s with constructor %s doesn't " |
981 |
| - " follow this convention." |
982 |
| - % (cls, init_signature)) |
983 |
| - # Extract and sort argument names excluding 'self' |
984 |
| - return sorted([p.name for p in parameters]) |
985 |
| - |
986 |
| - def get_params(self, deep=True): |
987 |
| - """Get parameters for this estimator. |
988 |
| -
|
989 |
| - Parameters |
990 |
| - ---------- |
991 |
| - deep : boolean, optional |
992 |
| - If True, will return the parameters for this estimator and |
993 |
| - contained subobjects that are estimators. |
994 |
| -
|
995 |
| - Returns |
996 |
| - ------- |
997 |
| - params : mapping of string to any |
998 |
| - Parameter names mapped to their values. |
999 |
| - """ |
1000 |
| - out = dict() |
1001 |
| - for key in self._get_param_names(): |
1002 |
| - # We need deprecation warnings to always be on in order to |
1003 |
| - # catch deprecated param values. |
1004 |
| - # This is set in utils/__init__.py but it gets overwritten |
1005 |
| - # when running under python3 somehow. |
1006 |
| - warnings.simplefilter("always", DeprecationWarning) |
1007 |
| - try: |
1008 |
| - with warnings.catch_warnings(record=True) as w: |
1009 |
| - value = getattr(self, key, None) |
1010 |
| - if len(w) and w[0].category == DeprecationWarning: |
1011 |
| - # if the parameter is deprecated, don't show it |
1012 |
| - continue |
1013 |
| - finally: |
1014 |
| - warnings.filters.pop(0) |
1015 |
| - |
1016 |
| - # XXX: should we rather test if instance of estimator? |
1017 |
| - if deep and hasattr(value, 'get_params'): |
1018 |
| - deep_items = value.get_params().items() |
1019 |
| - out.update((key + '__' + k, val) for k, val in deep_items) |
1020 |
| - out[key] = value |
1021 |
| - return out |
1022 |
| - |
1023 |
| - def set_params(self, **params): |
1024 |
| - """Set the parameters of this estimator. |
1025 |
| -
|
1026 |
| - The method works on simple estimators as well as on nested objects |
1027 |
| - (such as pipelines). The latter have parameters of the form |
1028 |
| - ``<component>__<parameter>`` so that it's possible to update each |
1029 |
| - component of a nested object. |
1030 |
| -
|
1031 |
| - Returns |
1032 |
| - ------- |
1033 |
| - self |
1034 |
| - """ |
1035 |
| - if not params: |
1036 |
| - # Simple optimisation to gain speed (inspect is slow) |
1037 |
| - return self |
1038 |
| - valid_params = self.get_params(deep=True) |
1039 |
| - # for key, value in iteritems(params): |
1040 |
| - for key, value in params.items(): |
1041 |
| - split = key.split('__', 1) |
1042 |
| - if len(split) > 1: |
1043 |
| - # nested objects case |
1044 |
| - name, sub_name = split |
1045 |
| - if name not in valid_params: |
1046 |
| - raise ValueError('Invalid parameter %s for estimator %s. ' |
1047 |
| - 'Check the list of available parameters ' |
1048 |
| - 'with `estimator.get_params().keys()`.' % |
1049 |
| - (name, self)) |
1050 |
| - sub_object = valid_params[name] |
1051 |
| - sub_object.set_params(**{sub_name: value}) |
1052 |
| - else: |
1053 |
| - # simple objects case |
1054 |
| - if key not in valid_params: |
1055 |
| - raise ValueError('Invalid parameter %s for estimator %s. ' |
1056 |
| - 'Check the list of available parameters ' |
1057 |
| - 'with `estimator.get_params().keys()`.' % |
1058 |
| - (key, self.__class__.__name__)) |
1059 |
| - setattr(self, key, value) |
1060 |
| - return self |
1061 |
| - |
1062 |
| - def __repr__(self): |
1063 |
| - from sklearn.base import _pprint |
1064 |
| - class_name = self.__class__.__name__ |
1065 |
| - return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False), |
1066 |
| - offset=len(class_name),),) |
1067 |
| - |
1068 |
| - # __getstate__ and __setstate__ are omitted because they only contain |
1069 |
| - # conditionals that are not satisfied by our objects (e.g., |
1070 |
| - # ``if type(self).__module__.startswith('sklearn.')``. |
1071 |
| - |
1072 | 938 |
|
1073 | 939 | def distribution_estimation_uniform(X):
|
1074 | 940 | """estimates a uniform distribution from an array of samples X
|
|
0 commit comments