|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +======================== |
| 4 | +OT for domain adaptation |
| 5 | +======================== |
| 6 | +
|
| 7 | +This example introduces a domain adaptation in a 2D setting and the 4 OTDA |
| 8 | +approaches currently supported in POT. |
| 9 | +
|
| 10 | +""" |
| 11 | + |
| 12 | +# Authors: Remi Flamary <[email protected]> |
| 13 | +# Stanislas Chambon <[email protected]> |
| 14 | +# |
| 15 | +# License: MIT License |
| 16 | + |
| 17 | +import matplotlib.pylab as pl |
| 18 | +import ot |
| 19 | + |
| 20 | + |
| 21 | +############################################################################## |
| 22 | +# generate data |
| 23 | +############################################################################## |
| 24 | + |
| 25 | +n_source_samples = 150 |
| 26 | +n_target_samples = 150 |
| 27 | + |
| 28 | +Xs, ys = ot.datasets.get_data_classif('3gauss', n_source_samples) |
| 29 | +Xt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples) |
| 30 | + |
| 31 | + |
| 32 | +############################################################################## |
| 33 | +# Instantiate the different transport algorithms and fit them |
| 34 | +############################################################################## |
| 35 | + |
| 36 | +# EMD Transport |
| 37 | +ot_emd = ot.da.EMDTransport() |
| 38 | +ot_emd.fit(Xs=Xs, Xt=Xt) |
| 39 | + |
| 40 | +# Sinkhorn Transport |
| 41 | +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1) |
| 42 | +ot_sinkhorn.fit(Xs=Xs, Xt=Xt) |
| 43 | + |
| 44 | +# Sinkhorn Transport with Group lasso regularization |
| 45 | +ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0) |
| 46 | +ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt) |
| 47 | + |
| 48 | +# Sinkhorn Transport with Group lasso regularization l1l2 |
| 49 | +ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20, |
| 50 | + verbose=True) |
| 51 | +ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt) |
| 52 | + |
| 53 | +# transport source samples onto target samples |
| 54 | +transp_Xs_emd = ot_emd.transform(Xs=Xs) |
| 55 | +transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs) |
| 56 | +transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs) |
| 57 | +transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs) |
| 58 | + |
| 59 | + |
| 60 | +############################################################################## |
| 61 | +# Fig 1 : plots source and target samples |
| 62 | +############################################################################## |
| 63 | + |
| 64 | +pl.figure(1, figsize=(10, 5)) |
| 65 | +pl.subplot(1, 2, 1) |
| 66 | +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') |
| 67 | +pl.xticks([]) |
| 68 | +pl.yticks([]) |
| 69 | +pl.legend(loc=0) |
| 70 | +pl.title('Source samples') |
| 71 | + |
| 72 | +pl.subplot(1, 2, 2) |
| 73 | +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') |
| 74 | +pl.xticks([]) |
| 75 | +pl.yticks([]) |
| 76 | +pl.legend(loc=0) |
| 77 | +pl.title('Target samples') |
| 78 | +pl.tight_layout() |
| 79 | + |
| 80 | + |
| 81 | +############################################################################## |
| 82 | +# Fig 2 : plot optimal couplings and transported samples |
| 83 | +############################################################################## |
| 84 | + |
| 85 | +param_img = {'interpolation': 'nearest', 'cmap': 'spectral'} |
| 86 | + |
| 87 | +pl.figure(2, figsize=(15, 8)) |
| 88 | +pl.subplot(2, 4, 1) |
| 89 | +pl.imshow(ot_emd.coupling_, **param_img) |
| 90 | +pl.xticks([]) |
| 91 | +pl.yticks([]) |
| 92 | +pl.title('Optimal coupling\nEMDTransport') |
| 93 | + |
| 94 | +pl.subplot(2, 4, 2) |
| 95 | +pl.imshow(ot_sinkhorn.coupling_, **param_img) |
| 96 | +pl.xticks([]) |
| 97 | +pl.yticks([]) |
| 98 | +pl.title('Optimal coupling\nSinkhornTransport') |
| 99 | + |
| 100 | +pl.subplot(2, 4, 3) |
| 101 | +pl.imshow(ot_lpl1.coupling_, **param_img) |
| 102 | +pl.xticks([]) |
| 103 | +pl.yticks([]) |
| 104 | +pl.title('Optimal coupling\nSinkhornLpl1Transport') |
| 105 | + |
| 106 | +pl.subplot(2, 4, 4) |
| 107 | +pl.imshow(ot_l1l2.coupling_, **param_img) |
| 108 | +pl.xticks([]) |
| 109 | +pl.yticks([]) |
| 110 | +pl.title('Optimal coupling\nSinkhornL1l2Transport') |
| 111 | + |
| 112 | +pl.subplot(2, 4, 5) |
| 113 | +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', |
| 114 | + label='Target samples', alpha=0.3) |
| 115 | +pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, |
| 116 | + marker='+', label='Transp samples', s=30) |
| 117 | +pl.xticks([]) |
| 118 | +pl.yticks([]) |
| 119 | +pl.title('Transported samples\nEmdTransport') |
| 120 | +pl.legend(loc="lower left") |
| 121 | + |
| 122 | +pl.subplot(2, 4, 6) |
| 123 | +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', |
| 124 | + label='Target samples', alpha=0.3) |
| 125 | +pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, |
| 126 | + marker='+', label='Transp samples', s=30) |
| 127 | +pl.xticks([]) |
| 128 | +pl.yticks([]) |
| 129 | +pl.title('Transported samples\nSinkhornTransport') |
| 130 | + |
| 131 | +pl.subplot(2, 4, 7) |
| 132 | +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', |
| 133 | + label='Target samples', alpha=0.3) |
| 134 | +pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys, |
| 135 | + marker='+', label='Transp samples', s=30) |
| 136 | +pl.xticks([]) |
| 137 | +pl.yticks([]) |
| 138 | +pl.title('Transported samples\nSinkhornLpl1Transport') |
| 139 | + |
| 140 | +pl.subplot(2, 4, 8) |
| 141 | +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', |
| 142 | + label='Target samples', alpha=0.3) |
| 143 | +pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys, |
| 144 | + marker='+', label='Transp samples', s=30) |
| 145 | +pl.xticks([]) |
| 146 | +pl.yticks([]) |
| 147 | +pl.title('Transported samples\nSinkhornL1l2Transport') |
| 148 | +pl.tight_layout() |
| 149 | + |
| 150 | +pl.show() |
0 commit comments