Skip to content

Commit 55840f6

Browse files
committed
move no da objects into utils.py
1 parent 892d7ce commit 55840f6

File tree

4 files changed

+222
-238
lines changed

4 files changed

+222
-238
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ The contributors to this library are:
136136
* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/)
137137
* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation)
138138
* [Léo Gautheron](https://github.com/aje) (GPU implementation)
139+
* [Nathalie Gayraud]()
140+
* [Stanislas Chambon](https://slasnista.github.io/)
139141

140142
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
141143

ot/da.py

Lines changed: 1 addition & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,13 @@
1010
# License: MIT License
1111

1212
import numpy as np
13-
import warnings
1413

1514
from .bregman import sinkhorn
1615
from .lp import emd
1716
from .utils import unif, dist, kernel
17+
from .utils import deprecated, BaseEstimator
1818
from .optim import cg
1919
from .optim import gcg
20-
from .deprecation import deprecated
2120

2221

2322
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -936,139 +935,6 @@ def predict(self, x):
936935
print("Warning, model not fitted yet, returning None")
937936
return None
938937

939-
##############################################################################
940-
# proposal
941-
##############################################################################
942-
943-
944-
# adapted from sklearn
945-
946-
class BaseEstimator(object):
947-
"""Base class for all estimators in scikit-learn
948-
Notes
949-
-----
950-
All estimators should specify all the parameters that can be set
951-
at the class level in their ``__init__`` as explicit keyword
952-
arguments (no ``*args`` or ``**kwargs``).
953-
"""
954-
955-
@classmethod
956-
def _get_param_names(cls):
957-
"""Get parameter names for the estimator"""
958-
try:
959-
from inspect import signature
960-
except ImportError:
961-
from .externals.funcsigs import signature
962-
# fetch the constructor or the original constructor before
963-
# deprecation wrapping if any
964-
init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
965-
if init is object.__init__:
966-
# No explicit constructor to introspect
967-
return []
968-
969-
# introspect the constructor arguments to find the model parameters
970-
# to represent
971-
init_signature = signature(init)
972-
# Consider the constructor parameters excluding 'self'
973-
parameters = [p for p in init_signature.parameters.values()
974-
if p.name != 'self' and p.kind != p.VAR_KEYWORD]
975-
for p in parameters:
976-
if p.kind == p.VAR_POSITIONAL:
977-
raise RuntimeError("scikit-learn estimators should always "
978-
"specify their parameters in the signature"
979-
" of their __init__ (no varargs)."
980-
" %s with constructor %s doesn't "
981-
" follow this convention."
982-
% (cls, init_signature))
983-
# Extract and sort argument names excluding 'self'
984-
return sorted([p.name for p in parameters])
985-
986-
def get_params(self, deep=True):
987-
"""Get parameters for this estimator.
988-
989-
Parameters
990-
----------
991-
deep : boolean, optional
992-
If True, will return the parameters for this estimator and
993-
contained subobjects that are estimators.
994-
995-
Returns
996-
-------
997-
params : mapping of string to any
998-
Parameter names mapped to their values.
999-
"""
1000-
out = dict()
1001-
for key in self._get_param_names():
1002-
# We need deprecation warnings to always be on in order to
1003-
# catch deprecated param values.
1004-
# This is set in utils/__init__.py but it gets overwritten
1005-
# when running under python3 somehow.
1006-
warnings.simplefilter("always", DeprecationWarning)
1007-
try:
1008-
with warnings.catch_warnings(record=True) as w:
1009-
value = getattr(self, key, None)
1010-
if len(w) and w[0].category == DeprecationWarning:
1011-
# if the parameter is deprecated, don't show it
1012-
continue
1013-
finally:
1014-
warnings.filters.pop(0)
1015-
1016-
# XXX: should we rather test if instance of estimator?
1017-
if deep and hasattr(value, 'get_params'):
1018-
deep_items = value.get_params().items()
1019-
out.update((key + '__' + k, val) for k, val in deep_items)
1020-
out[key] = value
1021-
return out
1022-
1023-
def set_params(self, **params):
1024-
"""Set the parameters of this estimator.
1025-
1026-
The method works on simple estimators as well as on nested objects
1027-
(such as pipelines). The latter have parameters of the form
1028-
``<component>__<parameter>`` so that it's possible to update each
1029-
component of a nested object.
1030-
1031-
Returns
1032-
-------
1033-
self
1034-
"""
1035-
if not params:
1036-
# Simple optimisation to gain speed (inspect is slow)
1037-
return self
1038-
valid_params = self.get_params(deep=True)
1039-
# for key, value in iteritems(params):
1040-
for key, value in params.items():
1041-
split = key.split('__', 1)
1042-
if len(split) > 1:
1043-
# nested objects case
1044-
name, sub_name = split
1045-
if name not in valid_params:
1046-
raise ValueError('Invalid parameter %s for estimator %s. '
1047-
'Check the list of available parameters '
1048-
'with `estimator.get_params().keys()`.' %
1049-
(name, self))
1050-
sub_object = valid_params[name]
1051-
sub_object.set_params(**{sub_name: value})
1052-
else:
1053-
# simple objects case
1054-
if key not in valid_params:
1055-
raise ValueError('Invalid parameter %s for estimator %s. '
1056-
'Check the list of available parameters '
1057-
'with `estimator.get_params().keys()`.' %
1058-
(key, self.__class__.__name__))
1059-
setattr(self, key, value)
1060-
return self
1061-
1062-
def __repr__(self):
1063-
from sklearn.base import _pprint
1064-
class_name = self.__class__.__name__
1065-
return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
1066-
offset=len(class_name),),)
1067-
1068-
# __getstate__ and __setstate__ are omitted because they only contain
1069-
# conditionals that are not satisfied by our objects (e.g.,
1070-
# ``if type(self).__module__.startswith('sklearn.')``.
1071-
1072938

1073939
def distribution_estimation_uniform(X):
1074940
"""estimates a uniform distribution from an array of samples X

ot/deprecation.py

Lines changed: 0 additions & 103 deletions
This file was deleted.

0 commit comments

Comments
 (0)