Skip to content

Commit 0316d55

Browse files
committed
Move normalize function in utils.py
1 parent 5bbea9c commit 0316d55

File tree

2 files changed

+39
-46
lines changed

2 files changed

+39
-46
lines changed

ot/da.py

Lines changed: 6 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from .bregman import sinkhorn
1515
from .lp import emd
16-
from .utils import unif, dist, kernel
16+
from .utils import unif, dist, kernel, cost_normalization
1717
from .utils import check_params, deprecated, BaseEstimator
1818
from .optim import cg
1919
from .optim import gcg
@@ -673,7 +673,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
673673
self.wt = wt
674674

675675
self.M = dist(xs, xt, metric=self.metric)
676-
self.normalizeM(norm)
676+
self.M = cost_normalization(self.M, norm)
677677
self.G = emd(ws, wt, self.M, max_iter)
678678
self.computed = True
679679

@@ -741,26 +741,6 @@ def predict(self, x, direction=1):
741741
# apply the delta to the interpolation
742742
return xf[idx, :] + x - x0[idx, :]
743743

744-
def normalizeM(self, norm):
745-
""" Apply normalization to the loss matrix
746-
747-
748-
Parameters
749-
----------
750-
norm : str
751-
type of normalization from 'median','max','log','loglog'
752-
753-
"""
754-
755-
if norm == "median":
756-
self.M /= float(np.median(self.M))
757-
elif norm == "max":
758-
self.M /= float(np.max(self.M))
759-
elif norm == "log":
760-
self.M = np.log(1 + self.M)
761-
elif norm == "loglog":
762-
self.M = np.log(1 + np.log(1 + self.M))
763-
764744

765745
@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
766746
" removed in 0.5 \nUse class SinkhornTransport instead.")
@@ -787,7 +767,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
787767
self.wt = wt
788768

789769
self.M = dist(xs, xt, metric=self.metric)
790-
self.normalizeM(norm)
770+
self.M = cost_normalization(self.M, norm)
791771
self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
792772
self.computed = True
793773

@@ -816,7 +796,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
816796
self.wt = wt
817797

818798
self.M = dist(xs, xt, metric=self.metric)
819-
self.normalizeM(norm)
799+
self.M = cost_normalization(self.M, norm)
820800
self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
821801
self.computed = True
822802

@@ -845,7 +825,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
845825
self.wt = wt
846826

847827
self.M = dist(xs, xt, metric=self.metric)
848-
self.normalizeM(norm)
828+
self.M = cost_normalization(self.M, norm)
849829
self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
850830
self.computed = True
851831

@@ -1001,7 +981,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1001981

1002982
# pairwise distance
1003983
self.cost_ = dist(Xs, Xt, metric=self.metric)
1004-
self.normalizeCost_(self.norm)
984+
self.cost_ = cost_normalization(self.cost_, self.norm)
1005985

1006986
if (ys is not None) and (yt is not None):
1007987

@@ -1183,26 +1163,6 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11831163

11841164
return transp_Xt
11851165

1186-
def normalizeCost_(self, norm):
1187-
""" Apply normalization to the loss matrix
1188-
1189-
1190-
Parameters
1191-
----------
1192-
norm : str
1193-
type of normalization from 'median','max','log','loglog'
1194-
1195-
"""
1196-
1197-
if norm == "median":
1198-
self.cost_ /= float(np.median(self.cost_))
1199-
elif norm == "max":
1200-
self.cost_ /= float(np.max(self.cost_))
1201-
elif norm == "log":
1202-
self.cost_ = np.log(1 + self.cost_)
1203-
elif norm == "loglog":
1204-
self.cost_ = np.log(1 + np.log(1 + self.cost_))
1205-
12061166

12071167
class SinkhornTransport(BaseTransport):
12081168
"""Domain Adaptation OT method based on Sinkhorn Algorithm

ot/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,39 @@ def dist0(n, method='lin_square'):
134134
return res
135135

136136

137+
def cost_normalization(C, norm=None):
    """Apply normalization to the loss matrix.

    Parameters
    ----------
    C : np.array (n1, n2)
        The cost matrix to normalize.
    norm : str, optional
        Type of normalization, one of 'median', 'max', 'log', 'loglog'.
        Any other value (including None) does not normalize.

    Returns
    -------
    C : np.array (n1, n2)
        The input cost matrix normalized according to the given norm.
        Whenever a normalization is applied a new array is returned and the
        input is left unmodified; when no normalization applies, the input
        object itself is returned.

    """
    if norm == "median":
        # Out-of-place division: an in-place `/=` would silently modify the
        # caller's array and fails on integer-dtype matrices.
        C = C / float(np.median(C))
    elif norm == "max":
        C = C / float(np.max(C))
    elif norm == "log":
        C = np.log(1 + C)
    elif norm == "loglog":
        C = np.log(1 + np.log(1 + C))

    return C
168+
169+
137170
def dots(*args):
138171
""" dots function for multiple matrix multiply """
139172
return reduce(np.dot, args)

0 commit comments

Comments
 (0)