@@ -40,12 +40,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
40
40
41
41
Parameters
42
42
----------
43
- a : np. ndarray (ns,)
43
+ a : ndarray, shape (ns,)
44
44
samples weights in the source domain
45
- b : np. ndarray (nt,) or np. ndarray (nt,nbb)
45
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
46
46
samples in the target domain, compute sinkhorn with multiple targets
47
47
and fixed M if b is a matrix (return OT loss + dual variables in log)
48
- M : np. ndarray (ns,nt)
48
+ M : ndarray, shape (ns, nt)
49
49
loss matrix
50
50
reg : float
51
51
Regularization term >0
@@ -64,7 +64,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
64
64
65
65
Returns
66
66
-------
67
- gamma : (ns x nt) ndarray
67
+ gamma : ndarray, shape (ns, nt)
68
68
Optimal transportation matrix for the given parameters
69
69
log : dict
70
70
log dictionary returned only if log==True in parameters
@@ -155,12 +155,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
155
155
156
156
Parameters
157
157
----------
158
- a : np. ndarray (ns,)
158
+ a : ndarray, shape (ns,)
159
159
samples weights in the source domain
160
- b : np. ndarray (nt,) or np. ndarray (nt,nbb)
160
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
161
161
samples in the target domain, compute sinkhorn with multiple targets
162
162
and fixed M if b is a matrix (return OT loss + dual variables in log)
163
- M : np. ndarray (ns,nt)
163
+ M : ndarray, shape (ns, nt)
164
164
loss matrix
165
165
reg : float
166
166
Regularization term >0
@@ -176,7 +176,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
176
176
log : bool, optional
177
177
record log if True
178
178
179
-
180
179
Returns
181
180
-------
182
181
W : (nt) ndarray or float
@@ -272,12 +271,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
272
271
273
272
Parameters
274
273
----------
275
- a : np. ndarray (ns,)
274
+ a : ndarray, shape (ns,)
276
275
samples weights in the source domain
277
- b : np. ndarray (nt,) or np. ndarray (nt,nbb)
276
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
278
277
samples in the target domain, compute sinkhorn with multiple targets
279
278
and fixed M if b is a matrix (return OT loss + dual variables in log)
280
- M : np. ndarray (ns,nt)
279
+ M : ndarray, shape (ns, nt)
281
280
loss matrix
282
281
reg : float
283
282
Regularization term >0
@@ -290,10 +289,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
290
289
log : bool, optional
291
290
record log if True
292
291
293
-
294
292
Returns
295
293
-------
296
- gamma : (ns x nt) ndarray
294
+ gamma : ndarray, shape (ns, nt)
297
295
Optimal transportation matrix for the given parameters
298
296
log : dict
299
297
log dictionary returned only if log==True in parameters
@@ -453,12 +451,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
453
451
454
452
Parameters
455
453
----------
456
- a : np. ndarray (ns,)
454
+ a : ndarray, shape (ns,)
457
455
samples weights in the source domain
458
- b : np. ndarray (nt,) or np. ndarray (nt,nbb)
456
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
459
457
samples in the target domain, compute sinkhorn with multiple targets
460
458
and fixed M if b is a matrix (return OT loss + dual variables in log)
461
- M : np. ndarray (ns,nt)
459
+ M : ndarray, shape (ns, nt)
462
460
loss matrix
463
461
reg : float
464
462
Regularization term >0
@@ -469,10 +467,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
469
467
log : bool, optional
470
468
record log if True
471
469
472
-
473
470
Returns
474
471
-------
475
- gamma : (ns x nt) ndarray
472
+ gamma : ndarray, shape (ns, nt)
476
473
Optimal transportation matrix for the given parameters
477
474
log : dict
478
475
log dictionary returned only if log==True in parameters
@@ -602,11 +599,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
602
599
603
600
Parameters
604
601
----------
605
- a : np. ndarray (ns,)
602
+ a : ndarray, shape (ns,)
606
603
samples weights in the source domain
607
- b : np. ndarray (nt,)
604
+ b : ndarray, shape (nt,)
608
605
samples in the target domain
609
- M : np. ndarray (ns,nt)
606
+ M : ndarray, shape (ns, nt)
610
607
loss matrix
611
608
reg : float
612
609
Regularization term >0
@@ -623,10 +620,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
623
620
log : bool, optional
624
621
record log if True
625
622
626
-
627
623
Returns
628
624
-------
629
- gamma : (ns x nt) ndarray
625
+ gamma : ndarray, shape (ns, nt)
630
626
Optimal transportation matrix for the given parameters
631
627
log : dict
632
628
log dictionary returned only if log==True in parameters
@@ -823,19 +819,19 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
823
819
824
820
Parameters
825
821
----------
826
- a : np. ndarray (ns,)
822
+ a : ndarray, shape (ns,)
827
823
samples weights in the source domain
828
- b : np. ndarray (nt,)
824
+ b : ndarray, shape (nt,)
829
825
samples in the target domain
830
- M : np. ndarray (ns,nt)
826
+ M : ndarray, shape (ns, nt)
831
827
loss matrix
832
828
reg : float
833
829
Regularization term >0
834
830
tau : float
835
831
threshold for max value in u or v for log scaling
836
832
tau : float
837
833
threshold for max value in u or v for log scaling
838
- warmstart : tible of vectors
834
+ warmstart : tuple of vectors
839
835
if given then starting values for alpha and beta log scalings
840
836
numItermax : int, optional
841
837
Max number of iterations
@@ -850,10 +846,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
850
846
log : bool, optional
851
847
record log if True
852
848
853
-
854
849
Returns
855
850
-------
856
- gamma : (ns x nt) ndarray
851
+ gamma : ndarray, shape (ns, nt)
857
852
Optimal transportation matrix for the given parameters
858
853
log : dict
859
854
log dictionary returned only if log==True in parameters
@@ -1006,13 +1001,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
1006
1001
1007
1002
Parameters
1008
1003
----------
1009
- A : np. ndarray (d,n)
1004
+ A : ndarray, shape (d,n)
1010
1005
n training distributions a_i of size d
1011
- M : np. ndarray (d,d)
1006
+ M : ndarray, shape (d,d)
1012
1007
loss matrix for OT
1013
1008
reg : float
1014
1009
Regularization term >0
1015
- weights : np. ndarray (n,)
1010
+ weights : ndarray, shape (n,)
1016
1011
Weights of each histogram a_i on the simplex (barycentric coordinates)
1017
1012
numItermax : int, optional
1018
1013
Max number of iterations
@@ -1102,11 +1097,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
1102
1097
1103
1098
Parameters
1104
1099
----------
1105
- A : np. ndarray (n,w, h)
1100
+ A : ndarray, shape (n, w, h)
1106
1101
n distributions (2D images) of size w x h
1107
1102
reg : float
1108
1103
Regularization term >0
1109
- weights : np. ndarray (n,)
1104
+ weights : ndarray, shape (n,)
1110
1105
Weights of each image on the simplex (barycentric coordinates)
1111
1106
numItermax : int, optional
1112
1107
Max number of iterations
@@ -1119,15 +1114,13 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
1119
1114
log : bool, optional
1120
1115
record log if True
1121
1116
1122
-
1123
1117
Returns
1124
1118
-------
1125
- a : (w,h) ndarray
1119
+ a : ndarray, shape (w, h)
1126
1120
2D Wasserstein barycenter
1127
1121
log : dict
1128
1122
log dictionary return only if log==True in parameters
1129
1123
1130
-
1131
1124
References
1132
1125
----------
1133
1126
@@ -1217,15 +1210,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
1217
1210
1218
1211
Parameters
1219
1212
----------
1220
- a : np. ndarray (d)
1213
+ a : ndarray, shape (d)
1221
1214
observed distribution
1222
- D : np. ndarray (d,n)
1215
+ D : ndarray, shape (d, n)
1223
1216
dictionary matrix
1224
- M : np. ndarray (d,d)
1217
+ M : ndarray, shape (d, d)
1225
1218
loss matrix
1226
- M0 : np. ndarray (n,n)
1219
+ M0 : ndarray, shape (n, n)
1227
1220
loss matrix
1228
- h0 : np. ndarray (n,)
1221
+ h0 : ndarray, shape (n,)
1229
1222
prior on h
1230
1223
reg : float
1231
1224
Regularization term >0 (Wasserstein data fitting)
@@ -1245,7 +1238,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
1245
1238
1246
1239
Returns
1247
1240
-------
1248
- a : (d,) ndarray
1241
+ a : ndarray, shape (d,)
1249
1242
Wasserstein barycenter
1250
1243
log : dict
1251
1244
log dictionary returned only if log==True in parameters
@@ -1325,15 +1318,15 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
1325
1318
1326
1319
Parameters
1327
1320
----------
1328
- X_s : np. ndarray (ns, d)
1321
+ X_s : ndarray, shape (ns, d)
1329
1322
samples in the source domain
1330
- X_t : np. ndarray (nt, d)
1323
+ X_t : ndarray, shape (nt, d)
1331
1324
samples in the target domain
1332
1325
reg : float
1333
1326
Regularization term >0
1334
- a : np. ndarray (ns,)
1327
+ a : ndarray, shape (ns,)
1335
1328
samples weights in the source domain
1336
- b : np. ndarray (nt,)
1329
+ b : ndarray, shape (nt,)
1337
1330
samples weights in the target domain
1338
1331
numItermax : int, optional
1339
1332
Max number of iterations
@@ -1347,7 +1340,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
1347
1340
1348
1341
Returns
1349
1342
-------
1350
- gamma : (ns x nt) ndarray
1343
+ gamma : ndarray, shape (ns, nt)
1351
1344
Regularized optimal transportation matrix for the given parameters
1352
1345
log : dict
1353
1346
log dictionary returned only if log==True in parameters
@@ -1415,15 +1408,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
1415
1408
1416
1409
Parameters
1417
1410
----------
1418
- X_s : np. ndarray (ns, d)
1411
+ X_s : ndarray, shape (ns, d)
1419
1412
samples in the source domain
1420
- X_t : np. ndarray (nt, d)
1413
+ X_t : ndarray, shape (nt, d)
1421
1414
samples in the target domain
1422
1415
reg : float
1423
1416
Regularization term >0
1424
- a : np. ndarray (ns,)
1417
+ a : ndarray, shape (ns,)
1425
1418
samples weights in the source domain
1426
- b : np. ndarray (nt,)
1419
+ b : ndarray, shape (nt,)
1427
1420
samples weights in the target domain
1428
1421
numItermax : int, optional
1429
1422
Max number of iterations
@@ -1437,7 +1430,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
1437
1430
1438
1431
Returns
1439
1432
-------
1440
- gamma : (ns x nt) ndarray
1433
+ gamma : ndarray, shape (ns, nt)
1441
1434
Regularized optimal transportation matrix for the given parameters
1442
1435
log : dict
1443
1436
log dictionary returned only if log==True in parameters
@@ -1523,15 +1516,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
1523
1516
1524
1517
Parameters
1525
1518
----------
1526
- X_s : np. ndarray (ns, d)
1519
+ X_s : ndarray, shape (ns, d)
1527
1520
samples in the source domain
1528
- X_t : np. ndarray (nt, d)
1521
+ X_t : ndarray, shape (nt, d)
1529
1522
samples in the target domain
1530
1523
reg : float
1531
1524
Regularization term >0
1532
- a : np. ndarray (ns,)
1525
+ a : ndarray, shape (ns,)
1533
1526
samples weights in the source domain
1534
- b : np. ndarray (nt,)
1527
+ b : ndarray, shape (nt,)
1535
1528
samples weights in the target domain
1536
1529
numItermax : int, optional
1537
1530
Max number of iterations
@@ -1542,17 +1535,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
1542
1535
log : bool, optional
1543
1536
record log if True
1544
1537
1545
-
1546
1538
Returns
1547
1539
-------
1548
- gamma : (ns x nt) ndarray
1540
+ gamma : ndarray, shape (ns, nt)
1549
1541
Regularized optimal transportation matrix for the given parameters
1550
1542
log : dict
1551
1543
log dictionary returned only if log==True in parameters
1552
1544
1553
1545
Examples
1554
1546
--------
1555
-
1556
1547
>>> n_s = 2
1557
1548
>>> n_t = 4
1558
1549
>>> reg = 0.1
@@ -1564,7 +1555,6 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
1564
1555
1565
1556
References
1566
1557
----------
1567
-
1568
1558
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
1569
1559
'''
1570
1560
if log :
0 commit comments