Skip to content

Commit a11b055

Browse files
qbarthelemycedricvincentcuazrflamary
authored
[DOC] Improve documentation rendering (#710)
* improve documentation rendering * Update RELEASES.md --------- Co-authored-by: Cédric Vincent-Cuaz <[email protected]> Co-authored-by: Rémi Flamary <[email protected]>
1 parent 9d00f96 commit a11b055

File tree

8 files changed

+125
-118
lines changed

8 files changed

+125
-118
lines changed

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
- Fixed numerical errors in `ot.gmm` (PR #690, Issue #689)
1414
- Add version number to the documentation (PR #696)
1515
- Update doc for default regularization in `ot.unbalanced` sinkhorn solvers (Issue #691, PR #700)
16+
- Clean documentation for `gromov`, `lp` and `unbalanced` folders (PR #710)
1617

1718
## 0.9.5
1819

ot/gromov/_lowrank.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ def lowrank_gromov_wasserstein_samples(
9292
9393
where :
9494
95-
- :math: `A` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the source domain.
96-
- :math: `B` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the target domain.
97-
- :math: `\mathcal{Q}_{A,B}` is quadratic objective function of the Gromov Wasserstein plan.
98-
- :math: `Q` and `R` are the low-rank matrix decomposition of the Gromov-Wasserstein plan.
99-
- :math: `g` is the weight vector for the low-rank decomposition of the Gromov-Wasserstein plan.
95+
- :math:`A` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the source domain.
96+
- :math:`B` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the target domain.
97+
- :math:`\mathcal{Q}_{A,B}` is quadratic objective function of the Gromov Wasserstein plan.
98+
- :math:`Q` and `R` are the low-rank matrix decomposition of the Gromov-Wasserstein plan.
99+
- :math:`g` is the weight vector for the low-rank decomposition of the Gromov-Wasserstein plan.
100100
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1).
101-
- :math: `r` is the rank of the Gromov-Wasserstein plan.
102-
- :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem.
101+
- :math:`r` is the rank of the Gromov-Wasserstein plan.
102+
- :math:`\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem.
103103
- :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term.
104104
105105

ot/gromov/_partial.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1002,18 +1002,18 @@ def solve_partial_gromov_linesearch(
10021002
Parameters
10031003
----------
10041004
1005-
G : array-like, shape(ns,nt)
1005+
G : array-like, shape(ns, nt)
10061006
The transport map at a given iteration of the FW
1007-
deltaG : array-like (ns,nt)
1007+
deltaG : array-like, shape (ns, nt)
10081008
Difference between the optimal map `Gc` found by linearization in the
10091009
FW algorithm and the value at a given iteration
10101010
cost_G : float
10111011
Value of the cost at `G`
1012-
df_G : array-like (ns,nt)
1012+
df_G : array-like, shape (ns, nt)
10131013
Gradient of the GW cost at `G`
1014-
df_Gc : array-like (ns,nt)
1014+
df_Gc : array-like, shape (ns, nt)
10151015
Gradient of the GW cost at `Gc`
1016-
M : array-like (ns,nt)
1016+
M : array-like, shape (ns, nt)
10171017
Cost matrix between the features.
10181018
reg : float
10191019
Regularization parameter.
@@ -1032,7 +1032,7 @@ def solve_partial_gromov_linesearch(
10321032
nb of function call. Useless here
10331033
cost_G : float
10341034
The value of the cost for the next iteration
1035-
df_G : array-like (ns,nt)
1035+
df_G : array-like, shape (ns, nt)
10361036
Updated gradient of the GW cost
10371037
10381038
References
@@ -1173,7 +1173,7 @@ def entropic_partial_gromov_wasserstein(
11731173
11741174
Returns
11751175
-------
1176-
:math: `gamma` : (dim_a, dim_b) ndarray
1176+
:math:`gamma` : ndarray, shape (dim_a, dim_b)
11771177
Optimal transportation matrix for the given parameters
11781178
log : dict
11791179
log dictionary returned only if `log` is `True`

ot/lp/solver_1d.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def emd_1d(
160160
where :
161161
162162
- d is the metric
163-
- x_a and x_b are the samples
163+
- :math:`x_a` and :math:`x_b` are the samples
164164
- a and b are the sample weights
165165
166166
This implementation only supports metrics
@@ -170,21 +170,21 @@ def emd_1d(
170170
171171
Parameters
172172
----------
173-
x_a : (ns,) or (ns, 1) ndarray, float64
173+
x_a : ndarray of float64, shape (ns,) or (ns, 1)
174174
Source dirac locations (on the real line)
175-
x_b : (nt,) or (ns, 1) ndarray, float64
175+
x_b : ndarray of float64, shape (nt,) or (ns, 1)
176176
Target dirac locations (on the real line)
177-
a : (ns,) ndarray, float64, optional
177+
a : ndarray of float64, shape (ns,), optional
178178
Source histogram (default is uniform weight)
179-
b : (nt,) ndarray, float64, optional
179+
b : ndarray of float64, shape (nt,), optional
180180
Target histogram (default is uniform weight)
181181
metric: str, optional (default='sqeuclidean')
182182
Metric to be used. Only works with either of the strings
183183
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
184184
p: float, optional (default=1.0)
185185
The p-norm to apply for if metric='minkowski'
186186
dense: boolean, optional (default=True)
187-
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
187+
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
188188
Otherwise returns a sparse representation using scipy's `coo_matrix`
189189
format. Due to implementation details, this function runs faster when
190190
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
@@ -198,7 +198,7 @@ def emd_1d(
198198
199199
Returns
200200
-------
201-
gamma: (ns, nt) ndarray
201+
gamma: ndarray, shape (ns, nt)
202202
Optimal transportation matrix for the given parameters
203203
log: dict
204204
If input log is True, a dictionary containing the cost
@@ -318,7 +318,7 @@ def emd2_1d(
318318
where :
319319
320320
- d is the metric
321-
- x_a and x_b are the samples
321+
- :math:`x_a` and :math:`x_b` are the samples
322322
- a and b are the sample weights
323323
324324
This implementation only supports metrics
@@ -328,21 +328,21 @@ def emd2_1d(
328328
329329
Parameters
330330
----------
331-
x_a : (ns,) or (ns, 1) ndarray, float64
331+
x_a : ndarray of float64, shape (ns,) or (ns, 1)
332332
Source dirac locations (on the real line)
333-
x_b : (nt,) or (ns, 1) ndarray, float64
333+
x_b : ndarray of float64, shape (nt,) or (ns, 1)
334334
Target dirac locations (on the real line)
335-
a : (ns,) ndarray, float64, optional
335+
a : ndarray of float64, shape (ns,), optional
336336
Source histogram (default is uniform weight)
337-
b : (nt,) ndarray, float64, optional
337+
b : ndarray of float64, shape (nt,), optional
338338
Target histogram (default is uniform weight)
339339
metric: str, optional (default='sqeuclidean')
340340
Metric to be used. Only works with either of the strings
341341
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
342342
p: float, optional (default=1.0)
343343
The p-norm to apply for if metric='minkowski'
344344
dense: boolean, optional (default=True)
345-
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
345+
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
346346
Otherwise returns a sparse representation using scipy's `coo_matrix`
347347
format. Only used if log is set to True. Due to implementation details,
348348
this function runs faster when dense is set to False.
@@ -405,9 +405,9 @@ def roll_cols(M, shifts):
405405
406406
Parameters
407407
----------
408-
M : (nr, nc) ndarray
408+
M : ndarray, shape (nr, nc)
409409
Matrix to shift
410-
shifts: int or (nr,) ndarray
410+
shifts: int or ndarray, shape (nr,)
411411
412412
Returns
413413
-------
@@ -1046,7 +1046,7 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
10461046
10471047
Parameters
10481048
----------
1049-
u_values: ndarray, shape (n, ...)
1049+
u_values : ndarray, shape (n, ...)
10501050
Samples
10511051
u_weights : ndarray, shape (n, ...), optional
10521052
samples weights in the source domain

ot/plot.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def plot1D_mat(
3232
r"""Plot matrix :math:`\mathbf{M}` with the source and target 1D distributions.
3333
3434
Creates a subplot with the source distribution :math:`\mathbf{a}` and target
35-
distribution :math:`\mathbf{b}`t.
35+
distribution :math:`\mathbf{b}`.
3636
In 'yx' mode (default), the source is on the left and
3737
the target on the top, and in 'xy' mode, source on the bottom (upside
3838
down) and the target on the left.
@@ -69,8 +69,9 @@ def plot1D_mat(
6969
ax2 : target plot ax
7070
ax3 : coupling plot ax
7171
72-
.. seealso::
73-
:func:`rescale_for_imshow_plot`
72+
See Also
73+
--------
74+
:func:`rescale_for_imshow_plot`
7475
"""
7576
assert plot_style in ["yx", "xy"], "plot_style should be 'yx' or 'xy'"
7677
na, nb = M.shape
@@ -188,8 +189,9 @@ def rescale_for_imshow_plot(x, y, n, m=None, a_y=None, b_y=None):
188189
yr : ndarray, shape (nx,)
189190
Rescaled y values (due to slicing, may have less elements than y)
190191
191-
.. seealso::
192-
:func:`plot1D_mat`
192+
See Also
193+
--------
194+
:func:`plot1D_mat`
193195
194196
"""
195197
# slice over the y values that are in the y range

ot/solvers.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def solve(
7979
8080
Parameters
8181
----------
82-
M : array_like, shape (dim_a, dim_b)
82+
M : array-like, shape (dim_a, dim_b)
8383
Loss matrix
8484
a : array-like, shape (dim_a,), optional
8585
Samples weights in the source domain (default is uniform)
@@ -88,10 +88,10 @@ def solve(
8888
reg : float, optional
8989
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
9090
OT)
91-
c : array-like (dim_a, dim_b), optional (default=None)
91+
c : array-like, shape (dim_a, dim_b), optional (default=None)
9292
Reference measure for the regularization.
9393
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
94-
If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
94+
If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
9595
reg_type : str, optional
9696
Type of regularization :math:`R` either "KL", "L2", "entropy",
9797
by default "KL". a tuple of functions can be provided for general
@@ -116,9 +116,9 @@ def solve(
116116
Number of OMP threads for exact OT solver, by default 1
117117
max_iter : int, optional
118118
Maximum number of iterations, by default None (default values in each solvers)
119-
plan_init : array_like, shape (dim_a, dim_b), optional
119+
plan_init : array-like, shape (dim_a, dim_b), optional
120120
Initialization of the OT plan for iterative methods, by default None
121-
potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
121+
potentials_init : (array-like(dim_a,),array-like(dim_b,)), optional
122122
Initialization of the OT dual potentials for iterative methods, by default None
123123
tol : _type_, optional
124124
Tolerance for solution precision, by default None (default values in each solvers)
@@ -628,11 +628,11 @@ def solve_gromov(
628628
629629
Parameters
630630
----------
631-
Ca : array_like, shape (dim_a, dim_a)
631+
Ca : array-like, shape (dim_a, dim_a)
632632
Cost matrix in the source domain
633-
Cb : array_like, shape (dim_b, dim_b)
633+
Cb : array-like, shape (dim_b, dim_b)
634634
Cost matrix in the target domain
635-
M : array_like, shape (dim_a, dim_b), optional
635+
M : array-like, shape (dim_a, dim_b), optional
636636
Linear cost matrix for Fused Gromov-Wasserstein (default is None).
637637
a : array-like, shape (dim_a,), optional
638638
Samples weights in the source domain (default is uniform)
@@ -669,7 +669,7 @@ def solve_gromov(
669669
max_iter : int, optional
670670
Maximum number of iterations, by default None (default values in each
671671
solvers)
672-
plan_init : array_like, shape (dim_a, dim_b), optional
672+
plan_init : array-like, shape (dim_a, dim_b), optional
673673
Initialization of the OT plan for iterative methods, by default None
674674
tol : float, optional
675675
Tolerance for solution precision, by default None (default values in
@@ -1342,10 +1342,10 @@ def solve_sample(
13421342
reg : float, optional
13431343
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
13441344
OT)
1345-
c : array-like (dim_a, dim_b), optional (default=None)
1345+
c : array-like, shape (dim_a, dim_b), optional (default=None)
13461346
Reference measure for the regularization.
13471347
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
1348-
If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
1348+
If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
13491349
reg_type : str, optional
13501350
Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
13511351
unbalanced : float or indexable object of length 1 or 2
@@ -1374,13 +1374,13 @@ def solve_sample(
13741374
Number of OMP threads for exact OT solver, by default 1
13751375
max_iter : int, optional
13761376
Maximum number of iteration, by default None (default values in each solvers)
1377-
plan_init : array_like, shape (dim_a, dim_b), optional
1377+
plan_init : array-like, shape (dim_a, dim_b), optional
13781378
Initialization of the OT plan for iterative methods, by default None
13791379
rank : int, optional
13801380
Rank of the OT matrix for lazy solers (method='factored'), by default 100
13811381
scaling : float, optional
13821382
Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95
1383-
potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
1383+
potentials_init : (array-like(dim_a,),array-like(dim_b,)), optional
13841384
Initialization of the OT dual potentials for iterative methods, by default None
13851385
tol : _type_, optional
13861386
Tolerance for solution precision, by default None (default values in each solvers)
@@ -1511,7 +1511,7 @@ def solve_sample(
15111511
.. math::
15121512
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
15131513
1514-
with M_{i,j} = d(x_i,y_j)
1514+
\text{with} \ M_{i,j} = d(x_i,y_j)
15151515
15161516
can be solved with the following code:
15171517
@@ -1530,7 +1530,7 @@ def solve_sample(
15301530
.. math::
15311531
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
15321532
1533-
with M_{i,j} = d(x_i,y_j)
1533+
\text{with} \ M_{i,j} = d(x_i,y_j)
15341534
15351535
can be solved with the following code:
15361536

0 commit comments

Comments
 (0)