@@ -5,10 +5,12 @@
 
 # Author: Remi Flamary <[email protected]>
 #         Nicolas Courty <[email protected]>
+# Kilian Fatras <[email protected]>
 #
 # License: MIT License
 
 import numpy as np
+from .utils import unif, dist
 
 
 def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -1296,3 +1298,302 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
         return np.sum(K0, axis=1), log
     else:
         return np.sum(K0, axis=1)
+
+
+def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+    '''
+    Solve the entropic regularization optimal transport problem and return the
+    OT matrix from empirical data
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+    where :
+
+    - :math:`M` is the (ns, nt) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+    Parameters
+    ----------
+    X_s : np.ndarray (ns, d)
+        samples in the source domain
+    X_t : np.ndarray (nt, d)
+        samples in the target domain
+    reg : float
+        Regularization term >0
+    a : np.ndarray (ns,)
+        samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns x nt) ndarray
+        Regularized optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+
+    >>> n_s = 2
+    >>> n_t = 2
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> emp_sinkhorn = empirical_sinkhorn(X_s, X_t, reg, verbose=False)
+    >>> print(emp_sinkhorn)
+    [[4.99977301e-01 2.26989344e-05]
+     [2.26989344e-05 4.99977301e-01]]
+
+
+    References
+    ----------
+
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+    '''
+
+    if a is None:
+        a = unif(np.shape(X_s)[0])
+    if b is None:
+        b = unif(np.shape(X_t)[0])
+
+    M = dist(X_s, X_t, metric=metric)
+
+    if log:
+        pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+        return pi, log
+    else:
+        pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+        return pi
+
+
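Since the wrapper builds the cost matrix internally, callers only supply the two sample clouds. Below is a minimal usage sketch with made-up data; it assumes this module is importable as `ot.bregman` (adjust the import to your install):

```python
import numpy as np

# Assumed import path; empirical_sinkhorn is the function added above.
from ot.bregman import empirical_sinkhorn

rng = np.random.RandomState(42)
n_s, n_t, d = 20, 30, 2
X_s = rng.randn(n_s, d)        # source samples
X_t = rng.randn(n_t, d) + 1.0  # shifted target samples
reg = 0.1                      # entropic regularization strength (> 0)

# a and b default to uniform weights when left as None.
G = empirical_sinkhorn(X_s, X_t, reg)
print(G.shape)  # (20, 30)
print(G.sum())  # ~1.0: G couples two probability vectors
```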
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+    '''
+    Solve the entropic regularization optimal transport problem from empirical
+    data and return the OT loss
+
+
+    The function solves the following optimization problem:
+
+    .. math::
+        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+    where :
+
+    - :math:`M` is the (ns, nt) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+    Parameters
+    ----------
+    X_s : np.ndarray (ns, d)
+        samples in the source domain
+    X_t : np.ndarray (nt, d)
+        samples in the target domain
+    reg : float
+        Regularization term >0
+    a : np.ndarray (ns,)
+        samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    W : (1,) ndarray
+        Optimal transportation loss for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+
+    >>> n_s = 2
+    >>> n_t = 2
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> loss_sinkhorn = empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
+    >>> print(loss_sinkhorn)
+    [4.53978687e-05]
+
+
+    References
+    ----------
+
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+    '''
+
+    if a is None:
+        a = unif(np.shape(X_s)[0])
+    if b is None:
+        b = unif(np.shape(X_t)[0])
+
+    M = dist(X_s, X_t, metric=metric)
+
+    if log:
+        sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        return sinkhorn_loss, log
+    else:
+        sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        return sinkhorn_loss
+
+
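As the body above shows, `empirical_sinkhorn2` only fills in uniform weights with `unif`, builds the cost matrix with `dist`, and delegates to `sinkhorn2`. A short consistency sketch of that equivalence, under the same assumed import paths as the imports in this diff:

```python
import numpy as np

# Assumed import paths (the diff imports unif and dist from .utils).
from ot.bregman import empirical_sinkhorn2, sinkhorn2
from ot.utils import dist, unif

rng = np.random.RandomState(0)
X_s, X_t = rng.randn(10, 3), rng.randn(15, 3)
reg = 0.5

# One-call path through the wrapper.
loss_wrapped = empirical_sinkhorn2(X_s, X_t, reg)

# The same computation spelled out by hand; numItermax matches the
# wrapper's numIterMax default so both runs do the same iterations.
a, b = unif(10), unif(15)
M = dist(X_s, X_t, metric='sqeuclidean')
loss_manual = sinkhorn2(a, b, M, reg, numItermax=10000)

assert np.allclose(loss_wrapped, loss_manual)
```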
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+    '''
+    Compute the sinkhorn divergence loss from empirical data
+
+    The function solves the following optimization problems and returns the
+    sinkhorn divergence :math:`S`:
+
+    .. math::
+
+        W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+        W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a)
+
+        W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
+
+        S &= W - \frac{1}{2} (W_a + W_b)
+
+    .. math::
+        s.t. \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+
+             \gamma_a 1 = a
+
+             \gamma_a^T 1 = a
+
+             \gamma_a \geq 0
+
+             \gamma_b 1 = b
+
+             \gamma_b^T 1 = b
+
+             \gamma_b \geq 0
+    where :
+
+    - :math:`M` (resp. :math:`M_a`, :math:`M_b`) is the (ns, nt) metric cost matrix (resp. (ns, ns) and (nt, nt))
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+    Parameters
+    ----------
+    X_s : np.ndarray (ns, d)
+        samples in the source domain
+    X_t : np.ndarray (nt, d)
+        samples in the target domain
+    reg : float
+        Regularization term >0
+    a : np.ndarray (ns,)
+        samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    W : (1,) ndarray
+        Sinkhorn divergence for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+
+    >>> n_s = 2
+    >>> n_t = 4
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, reg)
+    >>> print(emp_sinkhorn_div)
+    [2.99977435]
+
+
+    References
+    ----------
+
+    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS) 21, 2018
+    '''
+    if log:
+        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+
+        log = {}
+        log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
+        log['sinkhorn_loss_a'] = sinkhorn_loss_a
+        log['sinkhorn_loss_b'] = sinkhorn_loss_b
+        log['log_sinkhorn_ab'] = log_ab
+        log['log_sinkhorn_a'] = log_a
+        log['log_sinkhorn_b'] = log_b
+
+        return max(0, sinkhorn_div), log
+
+    else:
+        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+        return max(0, sinkhorn_div)
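A property worth checking follows directly from the definition: with `X_s = X_t` and matching weights, `W`, `W_a` and `W_b` coincide, so `S = W - (W_a + W_b)/2` vanishes; the `max(0, ...)` in the return only clamps numerical residue. A sanity-check sketch under the same assumed import path as above:

```python
import numpy as np

from ot.bregman import empirical_sinkhorn_divergence  # assumed import path

rng = np.random.RandomState(1)
X = rng.randn(25, 2)
reg = 0.1

# Divergence of a sample set against itself: the three losses coincide,
# so the debiased value is (numerically) zero.
print(empirical_sinkhorn_divergence(X, X, reg))        # ~0

# Against a shifted copy, the divergence is strictly positive.
print(empirical_sinkhorn_divergence(X, X + 2.0, reg))  # > 0
```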