|
@@ -1,16 +1,15 @@
 # -*- coding: utf-8 -*-
 """
-Domain adaptation with optimal transport
+Dimension reduction with optimal transport
 """
 
-
 import autograd.numpy as np
 from pymanopt.manifolds import Stiefel
 from pymanopt import Problem
 from pymanopt.solvers import SteepestDescent, TrustRegions
 
 def dist(x1,x2):
-    """ Compute squared euclidena distance between samples
+    """ Compute squared euclidean distance between samples
     """
     x1p2=np.sum(np.square(x1),1)
     x2p2=np.sum(np.square(x2),1)
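The first hunk is truncated before `dist` returns. Below is a minimal sketch of the usual completion, using the expansion `||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>`; the final `return` line is an assumption about the committed code, not part of the diff:

```python
import autograd.numpy as np

def dist(x1, x2):
    """Pairwise squared Euclidean distances between the rows of x1 and x2."""
    x1p2 = np.sum(np.square(x1), 1)   # ||x1_i||^2 for each row
    x2p2 = np.sum(np.square(x2), 1)   # ||x2_j||^2 for each row
    # ||x1_i - x2_j||^2 = ||x1_i||^2 + ||x2_j||^2 - 2 <x1_i, x2_j>  (assumed return)
    return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)
```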
|
@@ -40,18 +39,48 @@ def split_classes(X,y):
|
 
 def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
     """
-    Wasserstein Discriminant Analysis
+    Wasserstein Discriminant Analysis [11]_
 
     The function solves the following optimization problem:
 
     .. math::
-        P = arg\min_P \frac{\sum_i W(PX^i,PX^i)}{\sum_{i,j\neq i} W(PX^i,PX^j)}
+        P = \\text{arg}\\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
 
     where :
 
+    - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
     - :math:`W` is the entropic regularized Wasserstein distance
     - :math:`X^i` are the samples in the dataset corresponding to class i
 
+    Parameters
+    ----------
+    X : np.ndarray (n,d)
+        Training samples
+    y : np.ndarray (n,)
+        Labels for the training samples
+    p : int, optional
+        Size of the projected (reduced) dimension
+    reg : float, optional
+        Entropic regularization term >0
+    k : int, optional
+        Number of inner Sinkhorn iterations
+    solver : None | str, optional
+        None for steepest descent or 'TrustRegions' for the trust-regions algorithm
+    maxiter : int, optional
+        Max number of iterations
+    verbose : int, optional
+        Print information along iterations
+
+    Returns
+    -------
+    P : (d,p) ndarray
+        Optimal orthonormal projection
+
+    References
+    ----------
+
+    .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
+
     """
 
     mx=np.mean(X)
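The second hunk is truncated right after `mx=np.mean(X)`, before the objective and the manifold solver are set up. Below is a minimal sketch of how the ratio objective from the docstring can be wired to the pymanopt API implied by the imports (Stiefel manifold, `Problem`, `SteepestDescent`). It is an illustration under stated assumptions, not the committed code: the `sinkhorn` helper, the inline `split_classes`, and the loop structure are assumptions, and `dist` is the pairwise-distance helper completed in the sketch above.

```python
import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent

def split_classes(X, y):
    """Split the samples of X into a list of per-class arrays (assumed helper)."""
    return [X[y == c, :] for c in np.unique(y)]

def sinkhorn(w1, w2, M, reg, k):
    """k fixed-point Sinkhorn iterations; returns the entropic OT plan (assumed helper)."""
    K = np.exp(-M / reg)
    ui = np.ones(len(w1))
    for _ in range(k):
        vi = w2 / np.dot(K.T, ui)
        ui = w1 / np.dot(K, vi)
    return ui.reshape((-1, 1)) * K * vi.reshape((1, -1))

def wda_sketch(X, y, p=2, reg=1, k=10, maxiter=100):
    xc = split_classes(X, y)
    wc = [np.ones(len(x)) / len(x) for x in xc]  # uniform sample weights

    def cost(P):
        # Ratio of within-class to between-class regularized OT losses,
        # i.e. the objective in the docstring (each pair counted once).
        loss_w, loss_b = 0.0, 0.0
        for i, xi in enumerate(xc):
            pxi = np.dot(xi, P)
            for j in range(i, len(xc)):
                pxj = np.dot(xc[j], P)
                M = dist(pxi, pxj)               # dist() from the sketch above
                G = sinkhorn(wc[i], wc[j], M, reg, k)
                if i == j:
                    loss_w += np.sum(G * M)
                else:
                    loss_b += np.sum(G * M)
        return loss_w / loss_b

    manifold = Stiefel(X.shape[1], p)            # d x p orthonormal frames
    problem = Problem(manifold=manifold, cost=cost)
    return SteepestDescent(maxiter=maxiter).solve(problem)
```

A hypothetical call would be `Popt = wda_sketch(X, y, p=2, reg=1)`, after which `np.dot(X, Popt)` gives the projected samples.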
|
|