Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit a6faebf

Browse files
jbschlosserfacebook-github-bot
authored andcommitted
Support label_smoothing for cross_entropy in nestedtensor (#452)
Summary: Pull Request resolved: #452 Adds awareness of the new `label_smoothing` functionality in `cross_entropy` added in pytorch/pytorch#63122 to nestedtensor's implementation. Fixes broken test: `test.test_nested_tensor_functional.TestFunctional`. Reviewed By: cpuhrsch Differential Revision: D30730728 fbshipit-source-id: 04f146d6de7f764f165059e4b5654d7f39142e38
1 parent 4cc2a37 commit a6faebf

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

nestedtensor/csrc/python_functions.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ at::Tensor cross_entropy(
1717
c10::optional<bool>& size_average, // TODO: use
1818
c10::optional<int64_t>& ignore_index,
1919
c10::optional<bool>& reduce, // TODO: use
20-
c10::optional<std::string>& reduction) {
20+
c10::optional<std::string>& reduction,
21+
c10::optional<double> label_smoothing) {
2122
F::CrossEntropyFuncOptions::reduction_t redct;
2223
if (reduction.value() == "mean" || reduction.value() == "none") {
2324
redct = torch::kMean;
@@ -32,6 +33,9 @@ at::Tensor cross_entropy(
3233
if (ignore_index.has_value()) {
3334
options = options.ignore_index(ignore_index.value());
3435
}
36+
if (label_smoothing.has_value()) {
37+
options = options.label_smoothing(label_smoothing.value());
38+
}
3539

3640
return map_nested_tensor(
3741
[&, options](at::Tensor input_tensor, at::Tensor target_tensor) {
@@ -244,23 +248,26 @@ void add_functions(pybind11::module m) {
244248
c10::optional<bool> size_average, // TODO: use
245249
c10::optional<int64_t> ignore_index,
246250
c10::optional<bool> reduce, // TODO: use
247-
c10::optional<std::string> reduction) {
251+
c10::optional<std::string> reduction,
252+
c10::optional<double> label_smoothing) {
248253
return cross_entropy(
249254
input,
250255
target,
251256
weight,
252257
size_average,
253258
ignore_index,
254259
reduce,
255-
reduction);
260+
reduction,
261+
label_smoothing);
256262
},
257263
py::arg("input"),
258264
py::arg("target"),
259265
py::arg("weight") = nullptr,
260266
py::arg("size_average") = true,
261267
py::arg("ignore_index") = -100,
262268
py::arg("reduce") = true,
263-
py::arg("reduction") = "mean");
269+
py::arg("reduction") = "mean",
270+
py::arg("label_smoothing") = 0.0);
264271
}
265272
} // namespace nested_tensor
266273
} // namespace torch

0 commit comments

Comments
 (0)