Skip to content

Commit 7c79153

Browse files
authored
TruncateMod and Unique Op fixes (#2750)
1 parent 5961809 commit 7c79153

File tree

3 files changed

+20
-22
lines changed

3 files changed

+20
-22
lines changed

itex/core/kernels/gpu/cwise_op_mod.cc

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ limitations under the License.
2020

2121
namespace itex {
2222

23-
REGISTER2(BinaryOp, GPU, "TruncateMod", functor::safe_mod, int32, int64);
2423
REGISTER3(BinaryOp, GPU, "TruncateMod", functor::fmod, float, Eigen::bfloat16,
2524
Eigen::half);
2625

itex/core/kernels/gpu/unique_op.h

+16-15
Original file line numberDiff line numberDiff line change
@@ -440,24 +440,25 @@ Status DispatchRadixSort(OpKernelContext* context, const int32_t size,
440440
keys_out = mutable_keys_out;
441441
}
442442

443-
if (size <= KEYS_PER_ITEM * GROUP_SIZE) {
444-
using Rsortor = GroupRadixSortor<
445-
KeyT, /*key_per_item==*/KEYS_PER_ITEM, /*group_size=*/GROUP_SIZE,
446-
/*subgroup_size =*/SUBGROUP_SIZE, sycl::group<1>, ValueT>;
447-
// Compute the required local memory size
448-
size_t local_memory_size = Rsortor::LocalStorage::SIZE;
449-
const int32_t num_wg = 1;
450-
sycl::range<1> global_range(num_wg * GROUP_SIZE);
451-
sycl::range<1> local_range(GROUP_SIZE);
452-
453-
return LaunchRadixSortKernel<KeyT, ValueT, KEYS_PER_ITEM, SUBGROUP_SIZE,
454-
Rsortor>(
455-
stream, size, keys_in, indices_in, keys_out, indices_out, global_range,
456-
local_range, local_memory_size, num_bits);
457-
} else {
443+
if (size > KEYS_PER_ITEM * GROUP_SIZE &&
444+
!std::is_floating_point_v<KeyT>) { // DeviceRadixSort will write OOM for
445+
// floating-point (float/double) key types.
458446
return DispatchDeviceRadixSort(context, keys_in, indices_in, keys_out,
459447
indices_out, size);
460448
}
449+
using Rsortor = GroupRadixSortor<
450+
KeyT, /*key_per_item==*/KEYS_PER_ITEM, /*group_size=*/GROUP_SIZE,
451+
/*subgroup_size =*/SUBGROUP_SIZE, sycl::group<1>, ValueT>;
452+
// Compute the required local memory size
453+
size_t local_memory_size = Rsortor::LocalStorage::SIZE;
454+
const int32_t num_wg = 1;
455+
sycl::range<1> global_range(num_wg * GROUP_SIZE);
456+
sycl::range<1> local_range(GROUP_SIZE);
457+
458+
return LaunchRadixSortKernel<KeyT, ValueT, KEYS_PER_ITEM, SUBGROUP_SIZE,
459+
Rsortor>(
460+
stream, size, keys_in, indices_in, keys_out, indices_out, global_range,
461+
local_range, local_memory_size, num_bits);
461462
}
462463

463464
template <typename InputIteratorT, typename OutputIteratorT, typename BinaryOp>

test/benchmark/test_TruncateMod.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,17 @@
2424

2525
try:
2626
from intel_extension_for_tensorflow.python.test_func import test
27-
INT_COMPUTE_TYPE = [dtypes.int32, dtypes.int64]
27+
INT_COMPUTE_TYPE = [dtypes.int32]
2828
except ImportError:
2929
from tensorflow.python.platform import test
30-
INT_COMPUTE_TYPE = [dtypes.int32, dtypes.int64]
30+
INT_COMPUTE_TYPE = [dtypes.int32]
3131

3232
ITERATION = 5
3333

3434
class TruncateModTest(test.TestCase):
3535
def _test_impl(self, x_size, y_size, dtype):
36-
x = np.random.normal(size=x_size)
37-
x = constant_op.constant(x, dtype=dtype)
38-
y = np.random.normal(size=x_size)
39-
y = constant_op.constant(y, dtype=dtype)
36+
x = tf.random.uniform(shape=x_size, minval=0, maxval=100, dtype=dtype)
37+
y = tf.random.uniform(shape=y_size, minval=1, maxval=100, dtype=dtype)
4038
flush_cache()
4139
out_gpu = tf.raw_ops.TruncateMod(x=x, y=y)
4240

0 commit comments

Comments
 (0)