Skip to content

Commit 34c49d3

Browse files
Document torch.quantile interpolation kwarg (pytorch#70637)
Summary: clone of pytorch#59397 This PR documents the interpolation kwarg parameter added in pytorch#49267. Now that the forward compatibility period is over, we can expose this parameter. Pull Request resolved: pytorch#70637 Reviewed By: jbschlosser Differential Revision: D33411707 Pulled By: anjali411 fbshipit-source-id: f5f2d0a6739b3a855bbdf58fc671ac2f0342ce69
1 parent 616afcf commit 34c49d3

File tree

7 files changed

+54
-142
lines changed

7 files changed

+54
-142
lines changed

aten/src/ATen/autocast_mode.cpp

+4-8
Original file line numberDiff line numberDiff line change
@@ -521,14 +521,10 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
521521
KERNEL_CPU(ADD_NS(prod), "prod", Tensor(const Tensor &, c10::optional<at::ScalarType>), fp32)
522522
KERNEL_CPU(ADD_NS(prod), "prod.dim_int", Tensor(const Tensor &, int64_t, bool, c10::optional<at::ScalarType>), fp32)
523523
KERNEL_CPU(ADD_NS(prod), "prod.dim_Dimname", Tensor(const Tensor &, at::Dimname, bool, c10::optional<at::ScalarType>), fp32)
524-
KERNEL_CPU(ADD_NS(quantile), "quantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool), fp32)
525-
KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool), fp32)
526-
KERNEL_CPU(ADD_NS(quantile), "quantile.new", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
527-
KERNEL_CPU(ADD_NS(quantile), "quantile.new_scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
528-
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool), fp32)
529-
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool), fp32)
530-
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.new", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
531-
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.new_scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
524+
KERNEL_CPU(ADD_NS(quantile), "quantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
525+
KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
526+
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
527+
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
532528
KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::optional<bool>, c10::optional<bool>), fp32)
533529
KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
534530
KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>), fp32)

aten/src/ATen/native/Sorting.cpp

+8-93
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,6 @@ Tensor median_impl(const Tensor& self, bool ignore_nan) {
515515

516516
} // namespace
517517

518-
519518
Tensor& quantile_out(
520519
const Tensor& self,
521520
const Tensor& q,
@@ -527,8 +526,7 @@ Tensor& quantile_out(
527526
out,
528527
self,
529528
q,
530-
// NOLINTNEXTLINE(performance-move-const-arg)
531-
std::move(dim),
529+
dim,
532530
keepdim,
533531
get_quantile_interpolation_mode(interpolation),
534532
/*ignore_nan=*/false);
@@ -547,8 +545,7 @@ Tensor& quantile_out(
547545
return at::native::quantile_out(
548546
self,
549547
at::scalar_tensor(q, self.options()),
550-
// NOLINTNEXTLINE(performance-move-const-arg)
551-
std::move(dim),
548+
dim,
552549
keepdim,
553550
interpolation,
554551
out);
@@ -565,8 +562,7 @@ Tensor quantile(
565562
out,
566563
self,
567564
q,
568-
// NOLINTNEXTLINE(performance-move-const-arg)
569-
std::move(dim),
565+
dim,
570566
keepdim,
571567
get_quantile_interpolation_mode(interpolation),
572568
/*ignore_nan=*/false);
@@ -582,8 +578,7 @@ Tensor quantile(
582578
TORCH_CHECK(
583579
q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
584580
return at::native::quantile(
585-
// NOLINTNEXTLINE(performance-move-const-arg)
586-
self, at::scalar_tensor(q, self.options()), std::move(dim), keepdim, interpolation);
581+
self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation);
587582
}
588583

589584
Tensor& nanquantile_out(
@@ -597,8 +592,7 @@ Tensor& nanquantile_out(
597592
out,
598593
self,
599594
q,
600-
// NOLINTNEXTLINE(performance-move-const-arg)
601-
std::move(dim),
595+
dim,
602596
keepdim,
603597
get_quantile_interpolation_mode(interpolation),
604598
/*ignore_nan=*/true);
@@ -617,8 +611,7 @@ Tensor& nanquantile_out(
617611
return at::native::nanquantile_out(
618612
self,
619613
at::scalar_tensor(q, self.options()),
620-
// NOLINTNEXTLINE(performance-move-const-arg)
621-
std::move(dim),
614+
dim,
622615
keepdim,
623616
interpolation,
624617
out);
@@ -635,8 +628,7 @@ Tensor nanquantile(
635628
out,
636629
self,
637630
q,
638-
// NOLINTNEXTLINE(performance-move-const-arg)
639-
std::move(dim),
631+
dim,
640632
keepdim,
641633
get_quantile_interpolation_mode(interpolation),
642634
/*ignore_nan=*/true);
@@ -652,84 +644,7 @@ Tensor nanquantile(
652644
TORCH_CHECK(
653645
q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
654646
return at::native::nanquantile(
655-
// NOLINTNEXTLINE(performance-move-const-arg)
656-
self, at::scalar_tensor(q, self.options()), std::move(dim), keepdim, interpolation);
657-
}
658-
659-
Tensor& quantile_out(
660-
const Tensor& self,
661-
const Tensor& q,
662-
optional<int64_t> dim,
663-
bool keepdim,
664-
Tensor& out) {
665-
// NOLINTNEXTLINE(performance-move-const-arg)
666-
return at::native::quantile_out(self, q, std::move(dim), keepdim, "linear", out);
667-
}
668-
669-
Tensor& quantile_out(
670-
const Tensor& self,
671-
double q,
672-
optional<int64_t> dim,
673-
bool keepdim,
674-
Tensor& out) {
675-
// NOLINTNEXTLINE(performance-move-const-arg)
676-
return at::native::quantile_out(self, q, std::move(dim), keepdim, "linear", out);
677-
}
678-
679-
Tensor quantile(
680-
const Tensor& self,
681-
const Tensor& q,
682-
optional<int64_t> dim,
683-
bool keepdim) {
684-
// NOLINTNEXTLINE(performance-move-const-arg)
685-
return at::native::quantile(self, q, std::move(dim), keepdim, "linear");
686-
}
687-
688-
Tensor quantile(
689-
const Tensor& self,
690-
double q,
691-
optional<int64_t> dim,
692-
bool keepdim) {
693-
// NOLINTNEXTLINE(performance-move-const-arg)
694-
return at::native::quantile(self, q, std::move(dim), keepdim, "linear");
695-
}
696-
697-
Tensor& nanquantile_out(
698-
const Tensor& self,
699-
const Tensor& q,
700-
optional<int64_t> dim,
701-
bool keepdim,
702-
Tensor& out) {
703-
// NOLINTNEXTLINE(performance-move-const-arg)
704-
return at::native::nanquantile_out(self, q, std::move(dim), keepdim, "linear", out);
705-
}
706-
707-
Tensor& nanquantile_out(
708-
const Tensor& self,
709-
double q,
710-
optional<int64_t> dim,
711-
bool keepdim,
712-
Tensor& out) {
713-
// NOLINTNEXTLINE(performance-move-const-arg)
714-
return at::native::nanquantile_out(self, q, std::move(dim), keepdim, "linear", out);
715-
}
716-
717-
Tensor nanquantile(
718-
const Tensor& self,
719-
const Tensor& q,
720-
optional<int64_t> dim,
721-
bool keepdim) {
722-
// NOLINTNEXTLINE(performance-move-const-arg)
723-
return at::native::nanquantile(self, q, std::move(dim), keepdim, "linear");
724-
}
725-
726-
Tensor nanquantile(
727-
const Tensor& self,
728-
double q,
729-
optional<int64_t> dim,
730-
bool keepdim) {
731-
// NOLINTNEXTLINE(performance-move-const-arg)
732-
return at::native::nanquantile(self, q, std::move(dim), keepdim, "linear");
647+
self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation);
733648
}
734649

735650
std::tuple<Tensor&, Tensor&> kthvalue_out_cpu(

aten/src/ATen/native/native_functions.yaml

+8-31
Original file line numberDiff line numberDiff line change
@@ -7564,48 +7564,25 @@
75647564
device_check: NoCheck # TensorIterator
75657565
variants: method, function
75667566

7567-
# The following quantile signatures are DEPRECATED in favor of the new ones with the interpolation kwarg.
7568-
- func: quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
7569-
7570-
- func: quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False) -> Tensor
7571-
variants: method, function
7572-
7573-
- func: quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
7574-
7575-
- func: quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False) -> Tensor
7576-
variants: method, function
7577-
7578-
- func: nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
7579-
7580-
- func: nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False) -> Tensor
7581-
variants: method, function
7582-
7583-
- func: nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
7584-
7585-
- func: nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False) -> Tensor
7567+
- func: quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
75867568
variants: method, function
75877569

7588-
# To keep backward and forward compatibility, and to avoid ambiguity with the original signatures, dim, keepdim and interpolation
7589-
# parameters are required for now. Once the deprecated signatures are removed they will be made optional.
7590-
- func: quantile.new_scalar_out(Tensor self, float q, int? dim, bool keepdim, *, str interpolation, Tensor(a!) out) -> Tensor(a!)
7570+
- func: quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
75917571

7592-
- func: quantile.new_scalar(Tensor self, float q, int? dim, bool keepdim, *, str interpolation) -> Tensor
7572+
- func: quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
75937573
variants: method, function
75947574

7595-
- func: quantile.new_out(Tensor self, Tensor q, int? dim, bool keepdim, *, str interpolation, Tensor(a!) out) -> Tensor(a!)
7575+
- func: quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
75967576

7597-
- func: quantile.new(Tensor self, Tensor q, int? dim, bool keepdim, *, str interpolation) -> Tensor
7577+
- func: nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
75987578
variants: method, function
75997579

7600-
- func: nanquantile.new_scalar_out(Tensor self, float q, int? dim, bool keepdim, *, str interpolation, Tensor(a!) out) -> Tensor(a!)
7580+
- func: nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
76017581

7602-
- func: nanquantile.new_scalar(Tensor self, float q, int? dim, bool keepdim, *, str interpolation) -> Tensor
7582+
- func: nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
76037583
variants: method, function
76047584

7605-
- func: nanquantile.new_out(Tensor self, Tensor q, int? dim, bool keepdim, *, str interpolation, Tensor(a!) out) -> Tensor(a!)
7606-
7607-
- func: nanquantile.new(Tensor self, Tensor q, int? dim, bool keepdim, *, str interpolation) -> Tensor
7608-
variants: method, function
7585+
- func: nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
76097586

76107587
- func: sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
76117588
device_check: NoCheck # TensorIterator

test/backward_compatibility/check_backward_compatibility.py

+2
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@
109109
("aten::_inverse_helper", datetime.date(2021, 12, 31)),
110110
("aten::softplus_backward", datetime.date(2022, 1, 31)),
111111
("aten::softplus_backward.grad_input", datetime.date(2022, 1, 31)),
112+
("aten::quantile", datetime.date(2022, 9, 30)),
113+
("aten::nanquantile", datetime.date(2022, 9, 30)),
112114
]
113115

114116
ALLOW_LIST_COMPILED = [

torch/_tensor_docs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2879,13 +2879,13 @@ def callable(a, b) -> number
28792879
""")
28802880

28812881
add_docstr_all('quantile', r"""
2882-
quantile(q, dim=None, keepdim=False) -> Tensor
2882+
quantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor
28832883
28842884
See :func:`torch.quantile`
28852885
""")
28862886

28872887
add_docstr_all('nanquantile', r"""
2888-
nanquantile(q, dim=None, keepdim=False) -> Tensor
2888+
nanquantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor
28892889
28902890
See :func:`torch.nanquantile`
28912891
""")

torch/_torch_docs.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -6306,16 +6306,20 @@ def merge_dicts(*dicts):
63066306
""".format(**single_dim_common))
63076307

63086308
add_docstr(torch.quantile, r"""
6309-
quantile(input, q, dim=None, keepdim=False, *, out=None) -> Tensor
6309+
quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor
63106310
6311-
Computes the q-th quantiles of each row of the :attr:`input` tensor
6312-
along the dimension :attr:`dim`.
6311+
Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`.
63136312
63146313
To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location
63156314
of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with
6316-
indices ``i`` and ``j`` in the sorted order, result is computed using linear interpolation as follows:
6315+
indices ``i`` and ``j`` in the sorted order, result is computed according to the given
6316+
:attr:`interpolation` method as follows:
63176317
6318-
``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index.
6318+
- ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index.
6319+
- ``lower``: ``a``.
6320+
- ``higher``: ``b``.
6321+
- ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions).
6322+
- ``midpoint``: ``(a + b) / 2``.
63196323
63206324
If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size
63216325
equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction.
@@ -6330,6 +6334,9 @@ def merge_dicts(*dicts):
63306334
{keepdim}
63316335
63326336
Keyword arguments:
6337+
interpolation (string): interpolation method to use when the desired quantile lies between two data points.
6338+
Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``.
6339+
Default is ``linear``.
63336340
{out}
63346341
63356342
Example::
@@ -6353,10 +6360,22 @@ def merge_dicts(*dicts):
63536360
>>> a = torch.arange(4.)
63546361
>>> a
63556362
tensor([0., 1., 2., 3.])
6363+
>>> torch.quantile(a, 0.6, interpolation='linear')
6364+
tensor(1.8000)
6365+
>>> torch.quantile(a, 0.6, interpolation='lower')
6366+
tensor(1.)
6367+
>>> torch.quantile(a, 0.6, interpolation='higher')
6368+
tensor(2.)
6369+
>>> torch.quantile(a, 0.6, interpolation='midpoint')
6370+
tensor(1.5000)
6371+
>>> torch.quantile(a, 0.6, interpolation='nearest')
6372+
tensor(2.)
6373+
>>> torch.quantile(a, 0.4, interpolation='nearest')
6374+
tensor(1.)
63566375
""".format(**single_dim_common))
63576376

63586377
add_docstr(torch.nanquantile, r"""
6359-
nanquantile(input, q, dim=None, keepdim=False, *, out=None) -> Tensor
6378+
nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor
63606379
63616380
This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values,
63626381
computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did
@@ -6370,6 +6389,9 @@ def merge_dicts(*dicts):
63706389
{keepdim}
63716390
63726391
Keyword arguments:
6392+
interpolation (string): interpolation method to use when the desired quantile lies between two data points.
6393+
Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``.
6394+
Default is ``linear``.
63736395
{out}
63746396
63756397
Example::

torch/overrides.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -843,8 +843,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
843843
torch.q_zero_point: lambda input: -1,
844844
torch.qr: lambda input, some=True, out=None: -1,
845845
torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
846-
torch.quantile: lambda input, q, dim=None, keepdim=False, out=None: -1,
847-
torch.nanquantile: lambda input, q, dim=None, keepdim=False, out=None: -1,
846+
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
847+
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
848848
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
849849
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
850850
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,

0 commit comments

Comments
 (0)