@@ -440,24 +440,25 @@ Status DispatchRadixSort(OpKernelContext* context, const int32_t size,
     keys_out = mutable_keys_out;
   }
 
-  if (size <= KEYS_PER_ITEM * GROUP_SIZE) {
-    using Rsortor = GroupRadixSortor<
-        KeyT, /*key_per_item=*/KEYS_PER_ITEM, /*group_size=*/GROUP_SIZE,
-        /*subgroup_size=*/SUBGROUP_SIZE, sycl::group<1>, ValueT>;
-    // Compute the required local memory size
-    size_t local_memory_size = Rsortor::LocalStorage::SIZE;
-    const int32_t num_wg = 1;
-    sycl::range<1> global_range(num_wg * GROUP_SIZE);
-    sycl::range<1> local_range(GROUP_SIZE);
-
-    return LaunchRadixSortKernel<KeyT, ValueT, KEYS_PER_ITEM, SUBGROUP_SIZE,
-                                 Rsortor>(
-        stream, size, keys_in, indices_in, keys_out, indices_out, global_range,
-        local_range, local_memory_size, num_bits);
-  } else {
+  if (size > KEYS_PER_ITEM * GROUP_SIZE &&
+      !std::is_floating_point_v<KeyT>) {  // DeviceRadixSort will write OOM for
+                                          // floating-point (float/double) key types.
     return DispatchDeviceRadixSort(context, keys_in, indices_in, keys_out,
                                    indices_out, size);
   }
+  using Rsortor = GroupRadixSortor<
+      KeyT, /*key_per_item=*/KEYS_PER_ITEM, /*group_size=*/GROUP_SIZE,
+      /*subgroup_size=*/SUBGROUP_SIZE, sycl::group<1>, ValueT>;
+  // Compute the required local memory size
+  size_t local_memory_size = Rsortor::LocalStorage::SIZE;
+  const int32_t num_wg = 1;
+  sycl::range<1> global_range(num_wg * GROUP_SIZE);
+  sycl::range<1> local_range(GROUP_SIZE);
+
+  return LaunchRadixSortKernel<KeyT, ValueT, KEYS_PER_ITEM, SUBGROUP_SIZE,
+                               Rsortor>(
+      stream, size, keys_in, indices_in, keys_out, indices_out, global_range,
+      local_range, local_memory_size, num_bits);
 }
 
 template <typename InputIteratorT, typename OutputIteratorT, typename BinaryOp>
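For readers skimming the hunk, the dispatch decision after this change can be summarized as: large inputs with non-floating-point keys go to DeviceRadixSort, while everything else (including all float/double keys, per the comment about OOM writes) takes the single-work-group path. The sketch below is a standalone illustration, not ITEX code: PickSortPath, the SortPath enum, and the constant values chosen for KEYS_PER_ITEM and GROUP_SIZE are hypothetical stand-ins; only the condition itself mirrors the patch.

// Illustrative sketch of the new dispatch condition (assumed constants and
// hypothetical helper names; not part of the actual sources).
#include <cstdint>
#include <iostream>
#include <type_traits>

constexpr int32_t KEYS_PER_ITEM = 8;   // assumed tuning constant
constexpr int32_t GROUP_SIZE = 256;    // assumed tuning constant

enum class SortPath { kDeviceRadixSort, kGroupRadixSort };

template <typename KeyT>
SortPath PickSortPath(int32_t size) {
  // Mirrors the patched condition: use the device-wide radix sort only when
  // the input exceeds what one work-group handles AND the key type is not
  // floating point (the patch comment reports OOM writes for float/double).
  if (size > KEYS_PER_ITEM * GROUP_SIZE && !std::is_floating_point_v<KeyT>) {
    return SortPath::kDeviceRadixSort;
  }
  return SortPath::kGroupRadixSort;
}

int main() {
  std::cout << (PickSortPath<int32_t>(1 << 20) == SortPath::kDeviceRadixSort)
            << "\n";  // 1: large integral keys -> device-wide sort
  std::cout << (PickSortPath<float>(1 << 20) == SortPath::kGroupRadixSort)
            << "\n";  // 1: float keys now always take the group path
  std::cout << (PickSortPath<int32_t>(100) == SortPath::kGroupRadixSort)
            << "\n";  // 1: small inputs fit in a single work-group
}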