@@ -17,7 +17,8 @@ at::Tensor cross_entropy(
17
17
c10::optional<bool >& size_average, // TODO: use
18
18
c10::optional<int64_t >& ignore_index,
19
19
c10::optional<bool >& reduce, // TODO: use
20
- c10::optional<std::string>& reduction) {
20
+ c10::optional<std::string>& reduction,
21
+ c10::optional<double > label_smoothing) {
21
22
F::CrossEntropyFuncOptions::reduction_t redct;
22
23
if (reduction.value () == " mean" || reduction.value () == " none" ) {
23
24
redct = torch::kMean ;
@@ -32,6 +33,9 @@ at::Tensor cross_entropy(
32
33
if (ignore_index.has_value ()) {
33
34
options = options.ignore_index (ignore_index.value ());
34
35
}
36
+ if (label_smoothing.has_value ()) {
37
+ options = options.label_smoothing (label_smoothing.value ());
38
+ }
35
39
36
40
return map_nested_tensor (
37
41
[&, options](at::Tensor input_tensor, at::Tensor target_tensor) {
@@ -244,23 +248,26 @@ void add_functions(pybind11::module m) {
244
248
c10::optional<bool > size_average, // TODO: use
245
249
c10::optional<int64_t > ignore_index,
246
250
c10::optional<bool > reduce, // TODO: use
247
- c10::optional<std::string> reduction) {
251
+ c10::optional<std::string> reduction,
252
+ c10::optional<double > label_smoothing) {
248
253
return cross_entropy (
249
254
input,
250
255
target,
251
256
weight,
252
257
size_average,
253
258
ignore_index,
254
259
reduce,
255
- reduction);
260
+ reduction,
261
+ label_smoothing);
256
262
},
257
263
py::arg (" input" ),
258
264
py::arg (" target" ),
259
265
py::arg (" weight" ) = nullptr ,
260
266
py::arg (" size_average" ) = true ,
261
267
py::arg (" ignore_index" ) = -100 ,
262
268
py::arg (" reduce" ) = true ,
263
- py::arg (" reduction" ) = " mean" );
269
+ py::arg (" reduction" ) = " mean" ,
270
+ py::arg (" label_smoothing" ) = 0.0 );
264
271
}
265
272
} // namespace nested_tensor
266
273
} // namespace torch
0 commit comments