Skip to content

Commit 5180023

Browse files
authored
Merge pull request #48 from rflamary/remove_otda_v05
Remove deprecated OTDA Classes
2 parents 9351bfa + f5dcbc4 commit 5180023

File tree

2 files changed

+1
-346
lines changed

2 files changed

+1
-346
lines changed

ot/da.py

Lines changed: 1 addition & 283 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .bregman import sinkhorn
1616
from .lp import emd
1717
from .utils import unif, dist, kernel, cost_normalization
18-
from .utils import check_params, deprecated, BaseEstimator
18+
from .utils import check_params, BaseEstimator
1919
from .optim import cg
2020
from .optim import gcg
2121

@@ -740,288 +740,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
740740
return A, b
741741

742742

743-
@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
744-
"removed in 0.5"
745-
"\n\tfor standard transport use class EMDTransport instead.")
746-
class OTDA(object):
747-
748-
"""Class for domain adaptation with optimal transport as proposed in [5]
749-
750-
751-
References
752-
----------
753-
754-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
755-
"Optimal Transport for Domain Adaptation," in IEEE Transactions on
756-
Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
757-
758-
"""
759-
760-
def __init__(self, metric='sqeuclidean', norm=None):
761-
""" Class initialization"""
762-
self.xs = 0
763-
self.xt = 0
764-
self.G = 0
765-
self.metric = metric
766-
self.norm = norm
767-
self.computed = False
768-
769-
def fit(self, xs, xt, ws=None, wt=None, max_iter=100000):
770-
"""Fit domain adaptation between samples is xs and xt
771-
(with optional weights)"""
772-
self.xs = xs
773-
self.xt = xt
774-
775-
if wt is None:
776-
wt = unif(xt.shape[0])
777-
if ws is None:
778-
ws = unif(xs.shape[0])
779-
780-
self.ws = ws
781-
self.wt = wt
782-
783-
self.M = dist(xs, xt, metric=self.metric)
784-
self.M = cost_normalization(self.M, self.norm)
785-
self.G = emd(ws, wt, self.M, max_iter)
786-
self.computed = True
787-
788-
def interp(self, direction=1):
789-
"""Barycentric interpolation for the source (1) or target (-1) samples
790-
791-
This Barycentric interpolation solves for each source (resp target)
792-
sample xs (resp xt) the following optimization problem:
793-
794-
.. math::
795-
arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)
796-
797-
where k is the index of the sample in xs
798-
799-
For the moment only squared euclidean distance is provided but more
800-
metric could be used in the future.
801-
802-
"""
803-
if direction > 0: # >0 then source to target
804-
G = self.G
805-
w = self.ws.reshape((self.xs.shape[0], 1))
806-
x = self.xt
807-
else:
808-
G = self.G.T
809-
w = self.wt.reshape((self.xt.shape[0], 1))
810-
x = self.xs
811-
812-
if self.computed:
813-
if self.metric == 'sqeuclidean':
814-
return np.dot(G / w, x) # weighted mean
815-
else:
816-
print(
817-
"Warning, metric not handled yet, using weighted average")
818-
return np.dot(G / w, x) # weighted mean
819-
return None
820-
else:
821-
print("Warning, model not fitted yet, returning None")
822-
return None
823-
824-
def predict(self, x, direction=1):
825-
""" Out of sample mapping using the formulation from [6]
826-
827-
For each sample x to map, it finds the nearest source sample xs and
828-
map the samle x to the position xst+(x-xs) wher xst is the barycentric
829-
interpolation of source sample xs.
830-
831-
References
832-
----------
833-
834-
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
835-
Regularized discrete optimal transport. SIAM Journal on Imaging
836-
Sciences, 7(3), 1853-1882.
837-
838-
"""
839-
if direction > 0: # >0 then source to target
840-
xf = self.xt
841-
x0 = self.xs
842-
else:
843-
xf = self.xs
844-
x0 = self.xt
845-
846-
D0 = dist(x, x0) # dist netween new samples an source
847-
idx = np.argmin(D0, 1) # closest one
848-
xf = self.interp(direction) # interp the source samples
849-
# aply the delta to the interpolation
850-
return xf[idx, :] + x - x0[idx, :]
851-
852-
853-
@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
854-
" removed in 0.5 \nUse class SinkhornTransport instead.")
855-
class OTDA_sinkhorn(OTDA):
856-
857-
"""Class for domain adaptation with optimal transport with entropic
858-
regularization
859-
860-
861-
"""
862-
863-
def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
864-
"""Fit regularized domain adaptation between samples is xs and xt
865-
(with optional weights)"""
866-
self.xs = xs
867-
self.xt = xt
868-
869-
if wt is None:
870-
wt = unif(xt.shape[0])
871-
if ws is None:
872-
ws = unif(xs.shape[0])
873-
874-
self.ws = ws
875-
self.wt = wt
876-
877-
self.M = dist(xs, xt, metric=self.metric)
878-
self.M = cost_normalization(self.M, self.norm)
879-
self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
880-
self.computed = True
881-
882-
883-
@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be"
884-
" removed in 0.5 \nUse class SinkhornLpl1Transport instead.")
885-
class OTDA_lpl1(OTDA):
886-
887-
"""Class for domain adaptation with optimal transport with entropic and
888-
group regularization"""
889-
890-
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
891-
"""Fit regularized domain adaptation between samples is xs and xt
892-
(with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
893-
parameters"""
894-
self.xs = xs
895-
self.xt = xt
896-
897-
if wt is None:
898-
wt = unif(xt.shape[0])
899-
if ws is None:
900-
ws = unif(xs.shape[0])
901-
902-
self.ws = ws
903-
self.wt = wt
904-
905-
self.M = dist(xs, xt, metric=self.metric)
906-
self.M = cost_normalization(self.M, self.norm)
907-
self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
908-
self.computed = True
909-
910-
911-
@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be"
912-
" removed in 0.5 \nUse class SinkhornL1l2Transport instead.")
913-
class OTDA_l1l2(OTDA):
914-
915-
"""Class for domain adaptation with optimal transport with entropic
916-
and group lasso regularization"""
917-
918-
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
919-
"""Fit regularized domain adaptation between samples is xs and xt
920-
(with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
921-
parameters"""
922-
self.xs = xs
923-
self.xt = xt
924-
925-
if wt is None:
926-
wt = unif(xt.shape[0])
927-
if ws is None:
928-
ws = unif(xs.shape[0])
929-
930-
self.ws = ws
931-
self.wt = wt
932-
933-
self.M = dist(xs, xt, metric=self.metric)
934-
self.M = cost_normalization(self.M, self.norm)
935-
self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
936-
self.computed = True
937-
938-
939-
@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be"
940-
" removed in 0.5 \nUse class MappingTransport instead.")
941-
class OTDA_mapping_linear(OTDA):
942-
943-
"""Class for optimal transport with joint linear mapping estimation as in
944-
[8]
945-
"""
946-
947-
def __init__(self):
948-
""" Class initialization"""
949-
950-
self.xs = 0
951-
self.xt = 0
952-
self.G = 0
953-
self.L = 0
954-
self.bias = False
955-
self.computed = False
956-
self.metric = 'sqeuclidean'
957-
958-
def fit(self, xs, xt, mu=1, eta=1, bias=False, **kwargs):
959-
""" Fit domain adaptation between samples is xs and xt (with optional
960-
weights)"""
961-
self.xs = xs
962-
self.xt = xt
963-
self.bias = bias
964-
965-
self.ws = unif(xs.shape[0])
966-
self.wt = unif(xt.shape[0])
967-
968-
self.G, self.L = joint_OT_mapping_linear(
969-
xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
970-
self.computed = True
971-
972-
def mapping(self):
973-
return lambda x: self.predict(x)
974-
975-
def predict(self, x):
976-
""" Out of sample mapping estimated during the call to fit"""
977-
if self.computed:
978-
if self.bias:
979-
x = np.hstack((x, np.ones((x.shape[0], 1))))
980-
return x.dot(self.L) # aply the delta to the interpolation
981-
else:
982-
print("Warning, model not fitted yet, returning None")
983-
return None
984-
985-
986-
@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be"
987-
" removed in 0.5 \nUse class MappingTransport instead.")
988-
class OTDA_mapping_kernel(OTDA_mapping_linear):
989-
990-
"""Class for optimal transport with joint nonlinear mapping
991-
estimation as in [8]"""
992-
993-
def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian',
994-
sigma=1, **kwargs):
995-
""" Fit domain adaptation between samples is xs and xt """
996-
self.xs = xs
997-
self.xt = xt
998-
self.bias = bias
999-
1000-
self.ws = unif(xs.shape[0])
1001-
self.wt = unif(xt.shape[0])
1002-
self.kernel = kerneltype
1003-
self.sigma = sigma
1004-
self.kwargs = kwargs
1005-
1006-
self.G, self.L = joint_OT_mapping_kernel(
1007-
xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
1008-
self.computed = True
1009-
1010-
def predict(self, x):
1011-
""" Out of sample mapping estimated during the call to fit"""
1012-
1013-
if self.computed:
1014-
K = kernel(
1015-
x, self.xs, method=self.kernel, sigma=self.sigma,
1016-
**self.kwargs)
1017-
if self.bias:
1018-
K = np.hstack((K, np.ones((x.shape[0], 1))))
1019-
return K.dot(self.L)
1020-
else:
1021-
print("Warning, model not fitted yet, returning None")
1022-
return None
1023-
1024-
1025743
def distribution_estimation_uniform(X):
1026744
"""estimates a uniform distribution from an array of samples X
1027745

test/test_da.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -484,66 +484,3 @@ def test_linear_mapping_class():
484484
Cst = np.cov(Xst.T)
485485

486486
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
487-
488-
489-
def test_otda():
490-
491-
n_samples = 150 # nb samples
492-
np.random.seed(0)
493-
494-
xs, ys = ot.datasets.make_data_classif('3gauss', n_samples)
495-
xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples)
496-
497-
a, b = ot.unif(n_samples), ot.unif(n_samples)
498-
499-
# LP problem
500-
da_emd = ot.da.OTDA() # init class
501-
da_emd.fit(xs, xt) # fit distributions
502-
da_emd.interp() # interpolation of source samples
503-
da_emd.predict(xs) # interpolation of source samples
504-
505-
np.testing.assert_allclose(a, np.sum(da_emd.G, 1))
506-
np.testing.assert_allclose(b, np.sum(da_emd.G, 0))
507-
508-
# sinkhorn regularization
509-
lambd = 1e-1
510-
da_entrop = ot.da.OTDA_sinkhorn()
511-
da_entrop.fit(xs, xt, reg=lambd)
512-
da_entrop.interp()
513-
da_entrop.predict(xs)
514-
515-
np.testing.assert_allclose(
516-
a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
517-
np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
518-
519-
# non-convex Group lasso regularization
520-
reg = 1e-1
521-
eta = 1e0
522-
da_lpl1 = ot.da.OTDA_lpl1()
523-
da_lpl1.fit(xs, ys, xt, reg=reg, eta=eta)
524-
da_lpl1.interp()
525-
da_lpl1.predict(xs)
526-
527-
np.testing.assert_allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3)
528-
np.testing.assert_allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3)
529-
530-
# True Group lasso regularization
531-
reg = 1e-1
532-
eta = 2e0
533-
da_l1l2 = ot.da.OTDA_l1l2()
534-
da_l1l2.fit(xs, ys, xt, reg=reg, eta=eta, numItermax=20, verbose=True)
535-
da_l1l2.interp()
536-
da_l1l2.predict(xs)
537-
538-
np.testing.assert_allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3)
539-
np.testing.assert_allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3)
540-
541-
# linear mapping
542-
da_emd = ot.da.OTDA_mapping_linear() # init class
543-
da_emd.fit(xs, xt, numItermax=10) # fit distributions
544-
da_emd.predict(xs) # interpolation of source samples
545-
546-
# nonlinear mapping
547-
da_emd = ot.da.OTDA_mapping_kernel() # init class
548-
da_emd.fit(xs, xt, numItermax=10) # fit distributions
549-
da_emd.predict(xs) # interpolation of source samples

0 commit comments

Comments
 (0)