Skip to content

Commit 0930223

Browse files
committed
added deprecation warning on old classes
1 parent 326d163 commit 0930223

File tree

3 files changed

+126
-4
lines changed

3 files changed

+126
-4
lines changed

ot/da.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
# License: MIT License
1111

1212
import numpy as np
13+
import warnings
14+
1315
from .bregman import sinkhorn
1416
from .lp import emd
1517
from .utils import unif, dist, kernel
1618
from .optim import cg
1719
from .optim import gcg
18-
import warnings
20+
from .deprecation import deprecated
1921

2022

2123
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -632,6 +634,9 @@ def df(G):
632634
return G, L
633635

634636

637+
@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
638+
"removed in 0.5"
639+
"\n\tfor standard transport use class EMDTransport instead.")
635640
class OTDA(object):
636641

637642
"""Class for domain adaptation with optimal transport as proposed in [5]
@@ -758,10 +763,15 @@ def normalizeM(self, norm):
758763
self.M = np.log(1 + np.log(1 + self.M))
759764

760765

766+
@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
767+
" removed in 0.5 \nUse class SinkhornTransport instead.")
761768
class OTDA_sinkhorn(OTDA):
762769

763770
"""Class for domain adaptation with optimal transport with entropic
764-
regularization"""
771+
regularization
772+
773+
774+
"""
765775

766776
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
767777
"""Fit regularized domain adaptation between samples is xs and xt
@@ -783,6 +793,8 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
783793
self.computed = True
784794

785795

796+
@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be"
797+
" removed in 0.5 \nUse class SinkhornLpl1Transport instead.")
786798
class OTDA_lpl1(OTDA):
787799

788800
"""Class for domain adaptation with optimal transport with entropic and
@@ -810,6 +822,8 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
810822
self.computed = True
811823

812824

825+
@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be"
826+
" removed in 0.5 \nUse class SinkhornL1l2Transport instead.")
813827
class OTDA_l1l2(OTDA):
814828

815829
"""Class for domain adaptation with optimal transport with entropic
@@ -837,6 +851,8 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
837851
self.computed = True
838852

839853

854+
@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be"
855+
" removed in 0.5 \nUse class MappingTransport instead.")
840856
class OTDA_mapping_linear(OTDA):
841857

842858
"""Class for optimal transport with joint linear mapping estimation as in
@@ -882,6 +898,8 @@ def predict(self, x):
882898
return None
883899

884900

901+
@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be"
902+
" removed in 0.5 \nUse class MappingTransport instead.")
885903
class OTDA_mapping_kernel(OTDA_mapping_linear):
886904

887905
"""Class for optimal transport with joint nonlinear mapping

ot/deprecation.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
deprecated class from scikit-learn package
3+
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
4+
"""
5+
6+
import sys
7+
import warnings
8+
9+
__all__ = ["deprecated", ]
10+
11+
12+
class deprecated(object):
13+
"""Decorator to mark a function or class as deprecated.
14+
Issue a warning when the function is called/the class is instantiated and
15+
adds a warning to the docstring.
16+
The optional extra argument will be appended to the deprecation message
17+
and the docstring. Note: to use this with the default value for extra, put
18+
in an empty of parentheses:
19+
>>> from ot.deprecation import deprecated
20+
>>> @deprecated()
21+
... def some_function(): pass
22+
23+
Parameters
24+
----------
25+
extra : string
26+
to be added to the deprecation messages
27+
"""
28+
29+
# Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
30+
# but with many changes.
31+
32+
def __init__(self, extra=''):
33+
self.extra = extra
34+
35+
def __call__(self, obj):
36+
"""Call method
37+
Parameters
38+
----------
39+
obj : object
40+
"""
41+
if isinstance(obj, type):
42+
return self._decorate_class(obj)
43+
else:
44+
return self._decorate_fun(obj)
45+
46+
def _decorate_class(self, cls):
47+
msg = "Class %s is deprecated" % cls.__name__
48+
if self.extra:
49+
msg += "; %s" % self.extra
50+
51+
# FIXME: we should probably reset __new__ for full generality
52+
init = cls.__init__
53+
54+
def wrapped(*args, **kwargs):
55+
warnings.warn(msg, category=DeprecationWarning)
56+
return init(*args, **kwargs)
57+
58+
cls.__init__ = wrapped
59+
60+
wrapped.__name__ = '__init__'
61+
wrapped.__doc__ = self._update_doc(init.__doc__)
62+
wrapped.deprecated_original = init
63+
64+
return cls
65+
66+
def _decorate_fun(self, fun):
67+
"""Decorate function fun"""
68+
69+
msg = "Function %s is deprecated" % fun.__name__
70+
if self.extra:
71+
msg += "; %s" % self.extra
72+
73+
def wrapped(*args, **kwargs):
74+
warnings.warn(msg, category=DeprecationWarning)
75+
return fun(*args, **kwargs)
76+
77+
wrapped.__name__ = fun.__name__
78+
wrapped.__dict__ = fun.__dict__
79+
wrapped.__doc__ = self._update_doc(fun.__doc__)
80+
81+
return wrapped
82+
83+
def _update_doc(self, olddoc):
84+
newdoc = "DEPRECATED"
85+
if self.extra:
86+
newdoc = "%s: %s" % (newdoc, self.extra)
87+
if olddoc:
88+
newdoc = "%s\n\n%s" % (newdoc, olddoc)
89+
return newdoc
90+
91+
92+
def _is_deprecated(func):
93+
"""Helper to check if func is wraped by our deprecated decorator"""
94+
if sys.version_info < (3, 5):
95+
raise NotImplementedError("This is only available for python3.5 "
96+
"or above")
97+
closures = getattr(func, '__closure__', [])
98+
if closures is None:
99+
closures = []
100+
is_deprecated = ('deprecated' in ''.join([c.cell_contents
101+
for c in closures
102+
if isinstance(c.cell_contents, str)]))
103+
return is_deprecated

test/test_da.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,11 @@ def test_otda():
432432
da_emd.predict(xs) # interpolation of source samples
433433

434434

435-
if __name__ == "__main__":
435+
# if __name__ == "__main__":
436436

437+
# test_otda()
437438
# test_sinkhorn_transport_class()
438439
# test_emd_transport_class()
439440
# test_sinkhorn_l1l2_transport_class()
440441
# test_sinkhorn_lpl1_transport_class()
441-
test_mapping_transport_class()
442+
# test_mapping_transport_class()

0 commit comments

Comments
 (0)