Skip to content

Commit 06fab4c

Browse files
committed
more
1 parent b6fb148 commit 06fab4c

File tree

4 files changed

+80
-102
lines changed

4 files changed

+80
-102
lines changed

ot/bregman.py

Lines changed: 51 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
4040
4141
Parameters
4242
----------
43-
a : np.ndarray (ns,)
43+
a : ndarray, shape (ns,)
4444
samples weights in the source domain
45-
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
45+
b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
4646
samples in the target domain, compute sinkhorn with multiple targets
4747
and fixed M if b is a matrix (return OT loss + dual variables in log)
48-
M : np.ndarray (ns,nt)
48+
M : ndarray, shape (ns, nt)
4949
loss matrix
5050
reg : float
5151
Regularization term >0
@@ -64,7 +64,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
6464
6565
Returns
6666
-------
67-
gamma : (ns x nt) ndarray
67+
gamma : ndarray, shape (ns, nt)
6868
Optimal transportation matrix for the given parameters
6969
log : dict
7070
log dictionary return only if log==True in parameters
@@ -155,12 +155,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
155155
156156
Parameters
157157
----------
158-
a : np.ndarray (ns,)
158+
a : ndarray, shape (ns,)
159159
samples weights in the source domain
160-
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
160+
b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
161161
samples in the target domain, compute sinkhorn with multiple targets
162162
and fixed M if b is a matrix (return OT loss + dual variables in log)
163-
M : np.ndarray (ns,nt)
163+
M : ndarray, shape (ns, nt)
164164
loss matrix
165165
reg : float
166166
Regularization term >0
@@ -176,7 +176,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
176176
log : bool, optional
177177
record log if True
178178
179-
180179
Returns
181180
-------
182181
W : (nt) ndarray or float
@@ -272,12 +271,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
272271
273272
Parameters
274273
----------
275-
a : np.ndarray (ns,)
274+
a : ndarray, shape (ns,)
276275
samples weights in the source domain
277-
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
276+
b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
278277
samples in the target domain, compute sinkhorn with multiple targets
279278
and fixed M if b is a matrix (return OT loss + dual variables in log)
280-
M : np.ndarray (ns,nt)
279+
M : ndarray, shape (ns, nt)
281280
loss matrix
282281
reg : float
283282
Regularization term >0
@@ -290,10 +289,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
290289
log : bool, optional
291290
record log if True
292291
293-
294292
Returns
295293
-------
296-
gamma : (ns x nt) ndarray
294+
gamma : ndarray, shape (ns, nt)
297295
Optimal transportation matrix for the given parameters
298296
log : dict
299297
log dictionary return only if log==True in parameters
@@ -453,12 +451,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
453451
454452
Parameters
455453
----------
456-
a : np.ndarray (ns,)
454+
a : ndarray, shape (ns,)
457455
samples weights in the source domain
458-
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
456+
b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
459457
samples in the target domain, compute sinkhorn with multiple targets
460458
and fixed M if b is a matrix (return OT loss + dual variables in log)
461-
M : np.ndarray (ns,nt)
459+
M : ndarray, shape (ns, nt)
462460
loss matrix
463461
reg : float
464462
Regularization term >0
@@ -469,10 +467,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
469467
log : bool, optional
470468
record log if True
471469
472-
473470
Returns
474471
-------
475-
gamma : (ns x nt) ndarray
472+
gamma : ndarray, shape (ns, nt)
476473
Optimal transportation matrix for the given parameters
477474
log : dict
478475
log dictionary return only if log==True in parameters
@@ -602,11 +599,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
602599
603600
Parameters
604601
----------
605-
a : np.ndarray (ns,)
602+
a : ndarray, shape (ns,)
606603
samples weights in the source domain
607-
b : np.ndarray (nt,)
604+
b : ndarray, shape (nt,)
608605
samples in the target domain
609-
M : np.ndarray (ns,nt)
606+
M : ndarray, shape (ns, nt)
610607
loss matrix
611608
reg : float
612609
Regularization term >0
@@ -623,10 +620,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
623620
log : bool, optional
624621
record log if True
625622
626-
627623
Returns
628624
-------
629-
gamma : (ns x nt) ndarray
625+
gamma : ndarray, shape (ns, nt)
630626
Optimal transportation matrix for the given parameters
631627
log : dict
632628
log dictionary return only if log==True in parameters
@@ -823,19 +819,19 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
823819
824820
Parameters
825821
----------
826-
a : np.ndarray (ns,)
822+
a : ndarray, shape (ns,)
827823
samples weights in the source domain
828-
b : np.ndarray (nt,)
824+
b : ndarray, shape (nt,)
829825
samples in the target domain
830-
M : np.ndarray (ns,nt)
826+
M : ndarray, shape (ns, nt)
831827
loss matrix
832828
reg : float
833829
Regularization term >0
834830
tau : float
835831
thershold for max value in u or v for log scaling
836832
tau : float
837833
thershold for max value in u or v for log scaling
838-
warmstart : tible of vectors
834+
warmstart : tuple of vectors
839835
if given then sarting values for alpha an beta log scalings
840836
numItermax : int, optional
841837
Max number of iterations
@@ -850,10 +846,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
850846
log : bool, optional
851847
record log if True
852848
853-
854849
Returns
855850
-------
856-
gamma : (ns x nt) ndarray
851+
gamma : ndarray, shape (ns, nt)
857852
Optimal transportation matrix for the given parameters
858853
log : dict
859854
log dictionary return only if log==True in parameters
@@ -1006,13 +1001,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
10061001
10071002
Parameters
10081003
----------
1009-
A : np.ndarray (d,n)
1004+
A : ndarray, shape (d,n)
10101005
n training distributions a_i of size d
1011-
M : np.ndarray (d,d)
1006+
M : ndarray, shape (d,d)
10121007
loss matrix for OT
10131008
reg : float
10141009
Regularization term >0
1015-
weights : np.ndarray (n,)
1010+
weights : ndarray, shape (n,)
10161011
Weights of each histogram a_i on the simplex (barycentric coodinates)
10171012
numItermax : int, optional
10181013
Max number of iterations
@@ -1102,11 +1097,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
11021097
11031098
Parameters
11041099
----------
1105-
A : np.ndarray (n,w,h)
1100+
A : ndarray, shape (n, w, h)
11061101
n distributions (2D images) of size w x h
11071102
reg : float
11081103
Regularization term >0
1109-
weights : np.ndarray (n,)
1104+
weights : ndarray, shape (n,)
11101105
Weights of each image on the simplex (barycentric coodinates)
11111106
numItermax : int, optional
11121107
Max number of iterations
@@ -1119,15 +1114,13 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
11191114
log : bool, optional
11201115
record log if True
11211116
1122-
11231117
Returns
11241118
-------
1125-
a : (w,h) ndarray
1119+
a : ndarray, shape (w, h)
11261120
2D Wasserstein barycenter
11271121
log : dict
11281122
log dictionary return only if log==True in parameters
11291123
1130-
11311124
References
11321125
----------
11331126
@@ -1217,15 +1210,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
12171210
12181211
Parameters
12191212
----------
1220-
a : np.ndarray (d)
1213+
a : ndarray, shape (d)
12211214
observed distribution
1222-
D : np.ndarray (d,n)
1215+
D : ndarray, shape (d, n)
12231216
dictionary matrix
1224-
M : np.ndarray (d,d)
1217+
M : ndarray, shape (d, d)
12251218
loss matrix
1226-
M0 : np.ndarray (n,n)
1219+
M0 : ndarray, shape (n, n)
12271220
loss matrix
1228-
h0 : np.ndarray (n,)
1221+
h0 : ndarray, shape (n,)
12291222
prior on h
12301223
reg : float
12311224
Regularization term >0 (Wasserstein data fitting)
@@ -1245,7 +1238,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
12451238
12461239
Returns
12471240
-------
1248-
a : (d,) ndarray
1241+
a : ndarray, shape (d,)
12491242
Wasserstein barycenter
12501243
log : dict
12511244
log dictionary return only if log==True in parameters
@@ -1325,15 +1318,15 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
13251318
13261319
Parameters
13271320
----------
1328-
X_s : np.ndarray (ns, d)
1321+
X_s : ndarray, shape (ns, d)
13291322
samples in the source domain
1330-
X_t : np.ndarray (nt, d)
1323+
X_t : ndarray, shape (nt, d)
13311324
samples in the target domain
13321325
reg : float
13331326
Regularization term >0
1334-
a : np.ndarray (ns,)
1327+
a : ndarray, shape (ns,)
13351328
samples weights in the source domain
1336-
b : np.ndarray (nt,)
1329+
b : ndarray, shape (nt,)
13371330
samples weights in the target domain
13381331
numItermax : int, optional
13391332
Max number of iterations
@@ -1347,7 +1340,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
13471340
13481341
Returns
13491342
-------
1350-
gamma : (ns x nt) ndarray
1343+
gamma : ndarray, shape (ns, nt)
13511344
Regularized optimal transportation matrix for the given parameters
13521345
log : dict
13531346
log dictionary return only if log==True in parameters
@@ -1415,15 +1408,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
14151408
14161409
Parameters
14171410
----------
1418-
X_s : np.ndarray (ns, d)
1411+
X_s : ndarray, shape (ns, d)
14191412
samples in the source domain
1420-
X_t : np.ndarray (nt, d)
1413+
X_t : ndarray, shape (nt, d)
14211414
samples in the target domain
14221415
reg : float
14231416
Regularization term >0
1424-
a : np.ndarray (ns,)
1417+
a : ndarray, shape (ns,)
14251418
samples weights in the source domain
1426-
b : np.ndarray (nt,)
1419+
b : ndarray, shape (nt,)
14271420
samples weights in the target domain
14281421
numItermax : int, optional
14291422
Max number of iterations
@@ -1437,7 +1430,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
14371430
14381431
Returns
14391432
-------
1440-
gamma : (ns x nt) ndarray
1433+
gamma : ndarray, shape (ns, nt)
14411434
Regularized optimal transportation matrix for the given parameters
14421435
log : dict
14431436
log dictionary return only if log==True in parameters
@@ -1523,15 +1516,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15231516
15241517
Parameters
15251518
----------
1526-
X_s : np.ndarray (ns, d)
1519+
X_s : ndarray, shape (ns, d)
15271520
samples in the source domain
1528-
X_t : np.ndarray (nt, d)
1521+
X_t : ndarray, shape (nt, d)
15291522
samples in the target domain
15301523
reg : float
15311524
Regularization term >0
1532-
a : np.ndarray (ns,)
1525+
a : ndarray, shape (ns,)
15331526
samples weights in the source domain
1534-
b : np.ndarray (nt,)
1527+
b : ndarray, shape (nt,)
15351528
samples weights in the target domain
15361529
numItermax : int, optional
15371530
Max number of iterations
@@ -1542,17 +1535,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15421535
log : bool, optional
15431536
record log if True
15441537
1545-
15461538
Returns
15471539
-------
1548-
gamma : (ns x nt) ndarray
1540+
gamma : ndarray, shape (ns, nt)
15491541
Regularized optimal transportation matrix for the given parameters
15501542
log : dict
15511543
log dictionary return only if log==True in parameters
15521544
15531545
Examples
15541546
--------
1555-
15561547
>>> n_s = 2
15571548
>>> n_t = 4
15581549
>>> reg = 0.1
@@ -1564,7 +1555,6 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15641555
15651556
References
15661557
----------
1567-
15681558
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
15691559
'''
15701560
if log:

0 commit comments

Comments
 (0)