Skip to content

Commit a29e22d

Browse files
committed
addressed AG comments + adding random seed
1 parent 7d3fc95 commit a29e22d

File tree

5 files changed

+21
-14
lines changed

5 files changed

+21
-14
lines changed

examples/da/plot_otda_classes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
# License: MIT License
1616

1717
import matplotlib.pylab as pl
18+
import numpy as np
1819
import ot
1920

21+
np.random.seed(42)
2022

2123
# number of source and target points to generate
2224
ns = 150

examples/da/plot_otda_color_images.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
import numpy as np
2121
from scipy import ndimage
2222
import matplotlib.pylab as pl
23-
2423
import ot
2524

25+
np.random.seed(42)
26+
2627

2728
def im2mat(I):
2829
"""Converts and image to matrix (one pixel per line)"""

examples/da/plot_otda_d2.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,19 @@
1919
# License: MIT License
2020

2121
import matplotlib.pylab as pl
22+
import numpy as np
2223
import ot
2324

24-
# number of source and target points to generate
25-
ns = 150
26-
nt = 150
25+
np.random.seed(42)
2726

28-
Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
29-
Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
27+
n_samples_source = 150
28+
n_samples_target = 150
29+
30+
Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
31+
Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)
3032

3133
# Cost matrix
32-
M = ot.dist(Xs, Xt)
34+
M = ot.dist(Xs, Xt, metric='sqeuclidean')
3335

3436
# Instantiate the different transport algorithms and fit them
3537

examples/da/plot_otda_mapping.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,19 @@
2323
import ot
2424

2525

26-
np.random.seed(0)
26+
np.random.seed(42)
2727

2828
##############################################################################
2929
# generate
3030
##############################################################################
3131

3232
n = 100 # nb samples in source and target datasets
3333
theta = 2 * np.pi / 20
34-
nz = 0.1
35-
Xs, ys = ot.datasets.get_data_classif('gaussrot', n, nz=nz)
36-
Xs_new, _ = ot.datasets.get_data_classif('gaussrot', n, nz=nz)
37-
Xt, yt = ot.datasets.get_data_classif('gaussrot', n, theta=theta, nz=nz)
34+
noise_level = 0.1
35+
Xs, ys = ot.datasets.get_data_classif('gaussrot', n, nz=noise_level)
36+
Xs_new, _ = ot.datasets.get_data_classif('gaussrot', n, nz=noise_level)
37+
Xt, yt = ot.datasets.get_data_classif(
38+
'gaussrot', n, theta=theta, nz=noise_level)
3839

3940
# one of the target mode changes its variance (no linear mapping)
4041
Xt[yt == 2] *= 3
@@ -46,8 +47,7 @@
4647
kernel="linear", mu=1e0, eta=1e-8, bias=True,
4748
max_iter=20, verbose=True)
4849

49-
ot_mapping_linear.fit(
50-
Xs=Xs, Xt=Xt)
50+
ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
5151

5252
# for original source samples, transform applies barycentric mapping
5353
transp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)

examples/da/plot_otda_mapping_colors_images.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import matplotlib.pylab as pl
2424
import ot
2525

26+
np.random.seed(42)
27+
2628

2729
def im2mat(I):
2830
"""Converts and image to matrix (one pixel per line)"""

0 commit comments

Comments
 (0)