Skip to content

Commit c1598ec

Browse files
committed
DOC: jax.lax.top_k: fix rendering of return values
1 parent 27de854 commit c1598ec

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

jax/_src/lax/lax.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,12 +1232,24 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]:
12321232
k: integer specifying the number of top entries.
12331233
12341234
Returns:
1235-
values: array containing the top k values along the last axis.
1236-
indices: array containing the indices corresponding to values.
1235+
A tuple ``(values, indices)`` where
1236+
1237+
- ``values`` is an array containing the top k values along the last axis.
1238+
- ``indices`` is an array containing the indices corresponding to values.
12371239
12381240
See also:
1239-
- :func:`jax.lax.approx_max_k`
1240-
- :func:`jax.lax.approx_min_k`
1241+
- :func:`jax.lax.approx_max_k`
1242+
- :func:`jax.lax.approx_min_k`
1243+
1244+
Example:
1245+
Find the largest three values, and their indices, within an array:
1246+
1247+
>>> x = jnp.array([9., 3., 6., 4., 10.])
1248+
>>> values, indices = jax.lax.top_k(x, 3)
1249+
>>> values
1250+
Array([10., 9., 6.], dtype=float32)
1251+
>>> indices
1252+
Array([4, 0, 2], dtype=int32)
12411253
"""
12421254
if core.is_constant_dim(k):
12431255
k = int(k)

0 commit comments

Comments
 (0)