Skip to content

Commit 65de6fc

Browse files
committed
pass on examples | introduced RandomState
1 parent a29e22d commit 65de6fc

File tree

5 files changed

+59
-22
lines changed

5 files changed

+59
-22
lines changed

examples/da/plot_otda_classes.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,23 @@
1515
# License: MIT License
1616

1717
import matplotlib.pylab as pl
18-
import numpy as np
1918
import ot
2019

21-
np.random.seed(42)
2220

23-
# number of source and target points to generate
24-
ns = 150
25-
nt = 150
21+
##############################################################################
22+
# generate data
23+
##############################################################################
24+
25+
n_source_samples = 150
26+
n_target_samples = 150
27+
28+
Xs, ys = ot.datasets.get_data_classif('3gauss', n_source_samples)
29+
Xt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples)
2630

27-
Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
28-
Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
2931

32+
##############################################################################
3033
# Instantiate the different transport algorithms and fit them
34+
##############################################################################
3135

3236
# EMD Transport
3337
ot_emd = ot.da.EMDTransport()
@@ -52,6 +56,7 @@
5256
transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
5357
transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)
5458

59+
5560
##############################################################################
5661
# Fig 1 : plots source and target samples
5762
##############################################################################
@@ -72,6 +77,7 @@
7277
pl.title('Target samples')
7378
pl.tight_layout()
7479

80+
7581
##############################################################################
7682
# Fig 2 : plot optimal couplings and transported samples
7783
##############################################################################

examples/da/plot_otda_color_images.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import matplotlib.pylab as pl
2323
import ot
2424

25-
np.random.seed(42)
25+
26+
r = np.random.RandomState(42)
2627

2728

2829
def im2mat(I):
@@ -39,6 +40,10 @@ def minmax(I):
3940
return np.clip(I, 0, 1)
4041

4142

43+
##############################################################################
44+
# generate data
45+
##############################################################################
46+
4247
# Loading images
4348
I1 = ndimage.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
4449
I2 = ndimage.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
@@ -48,12 +53,17 @@ def minmax(I):
4853

4954
# training samples
5055
nb = 1000
51-
idx1 = np.random.randint(X1.shape[0], size=(nb,))
52-
idx2 = np.random.randint(X2.shape[0], size=(nb,))
56+
idx1 = r.randint(X1.shape[0], size=(nb,))
57+
idx2 = r.randint(X2.shape[0], size=(nb,))
5358

5459
Xs = X1[idx1, :]
5560
Xt = X2[idx2, :]
5661

62+
63+
##############################################################################
64+
# Instantiate the different transport algorithms and fit them
65+
##############################################################################
66+
5767
# EMDTransport
5868
ot_emd = ot.da.EMDTransport()
5969
ot_emd.fit(Xs=Xs, Xt=Xt)
@@ -75,6 +85,7 @@ def minmax(I):
7585
I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
7686
I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
7787

88+
7889
##############################################################################
7990
# plot original image
8091
##############################################################################
@@ -91,6 +102,7 @@ def minmax(I):
91102
pl.axis('off')
92103
pl.title('Image 2')
93104

105+
94106
##############################################################################
95107
# scatter plot of colors
96108
##############################################################################
@@ -112,6 +124,7 @@ def minmax(I):
112124
pl.title('Image 2')
113125
pl.tight_layout()
114126

127+
115128
##############################################################################
116129
# plot new images
117130
##############################################################################

examples/da/plot_otda_d2.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
# License: MIT License
2020

2121
import matplotlib.pylab as pl
22-
import numpy as np
2322
import ot
2423

25-
np.random.seed(42)
24+
25+
##############################################################################
26+
# generate data
27+
##############################################################################
2628

2729
n_samples_source = 150
2830
n_samples_target = 150
@@ -33,7 +35,10 @@
3335
# Cost matrix
3436
M = ot.dist(Xs, Xt, metric='sqeuclidean')
3537

38+
39+
##############################################################################
3640
# Instantiate the different transport algorithms and fit them
41+
##############################################################################
3742

3843
# EMD Transport
3944
ot_emd = ot.da.EMDTransport()
@@ -52,6 +57,7 @@
5257
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
5358
transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
5459

60+
5561
##############################################################################
5662
# Fig 1 : plots source and target samples + matrix of pairwise distance
5763
##############################################################################
@@ -78,6 +84,7 @@
7884
pl.title('Matrix of pairwise distances')
7985
pl.tight_layout()
8086

87+
8188
##############################################################################
8289
# Fig 2 : plots optimal couplings for the different methods
8390
##############################################################################
@@ -127,6 +134,7 @@
127134
pl.title('Main coupling coefficients\nSinkhornLpl1Transport')
128135
pl.tight_layout()
129136

137+
130138
##############################################################################
131139
# Fig 3 : plot transported samples
132140
##############################################################################

examples/da/plot_otda_mapping.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,31 @@
2323
import ot
2424

2525

26-
np.random.seed(42)
27-
2826
##############################################################################
29-
# generate
27+
# generate data
3028
##############################################################################
3129

32-
n = 100 # nb samples in source and target datasets
30+
n_source_samples = 100
31+
n_target_samples = 100
3332
theta = 2 * np.pi / 20
3433
noise_level = 0.1
35-
Xs, ys = ot.datasets.get_data_classif('gaussrot', n, nz=noise_level)
36-
Xs_new, _ = ot.datasets.get_data_classif('gaussrot', n, nz=noise_level)
34+
35+
Xs, ys = ot.datasets.get_data_classif(
36+
'gaussrot', n_source_samples, nz=noise_level)
37+
Xs_new, _ = ot.datasets.get_data_classif(
38+
'gaussrot', n_source_samples, nz=noise_level)
3739
Xt, yt = ot.datasets.get_data_classif(
38-
'gaussrot', n, theta=theta, nz=noise_level)
40+
'gaussrot', n_target_samples, theta=theta, nz=noise_level)
3941

4042
# one of the target mode changes its variance (no linear mapping)
4143
Xt[yt == 2] *= 3
4244
Xt = Xt + 4
4345

4446

47+
##############################################################################
48+
# Instantiate the different transport algorithms and fit them
49+
##############################################################################
50+
4551
# MappingTransport with linear kernel
4652
ot_mapping_linear = ot.da.MappingTransport(
4753
kernel="linear", mu=1e0, eta=1e-8, bias=True,
@@ -80,6 +86,7 @@
8086
pl.legend(loc=0)
8187
pl.title('Source and target distributions')
8288

89+
8390
##############################################################################
8491
# plot transported samples
8592
##############################################################################

examples/da/plot_otda_mapping_colors_images.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import matplotlib.pylab as pl
2424
import ot
2525

26-
np.random.seed(42)
26+
r = np.random.RandomState(42)
2727

2828

2929
def im2mat(I):
@@ -54,8 +54,8 @@ def minmax(I):
5454

5555
# training samples
5656
nb = 1000
57-
idx1 = np.random.randint(X1.shape[0], size=(nb,))
58-
idx2 = np.random.randint(X2.shape[0], size=(nb,))
57+
idx1 = r.randint(X1.shape[0], size=(nb,))
58+
idx2 = r.randint(X2.shape[0], size=(nb,))
5959

6060
Xs = X1[idx1, :]
6161
Xt = X2[idx2, :]
@@ -91,6 +91,7 @@ def minmax(I):
9191
X1tn = ot_mapping_gaussian.transform(Xs=X1) # use the estimated mapping
9292
Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
9393

94+
9495
##############################################################################
9596
# plot original images
9697
##############################################################################
@@ -107,6 +108,7 @@ def minmax(I):
107108
pl.title('Image 2')
108109
pl.tight_layout()
109110

111+
110112
##############################################################################
111113
# plot pixel values distribution
112114
##############################################################################
@@ -128,6 +130,7 @@ def minmax(I):
128130
pl.title('Image 2')
129131
pl.tight_layout()
130132

133+
131134
##############################################################################
132135
# plot transformed images
133136
##############################################################################

0 commit comments

Comments
 (0)