Skip to content

Commit 6467251

Browse files
authored
Add index check for embedding kernel
Differential Revision: D75982682 Pull Request resolved: #11375
1 parent c120b35 commit 6467251

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

kernels/quantized/cpu/op_embedding.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,22 @@ void embedding_byte_per_channel(
153153

154154
for (int i = 0; i < indices.numel(); i++) {
155155
int64_t index = indices_ptr[i];
156+
157+
// Check if index is out of bounds for both weight and weight_scales
158+
ET_CHECK_MSG(
159+
index >= 0 && index < weight.size(0),
160+
"Index out of bounds for weight: index %" PRId64
161+
" must be in range [0, %zd)",
162+
index,
163+
weight.size(0));
164+
165+
ET_CHECK_MSG(
166+
index >= 0 && index < weight_scales.size(0),
167+
"Index out of bounds for weight_scales: index %" PRId64
168+
" must be in range [0, %zd)",
169+
index,
170+
weight_scales.size(0));
171+
156172
// If using groupwise embedding
157173
int32_t qparams_index = index * num_groups_per_channel;
158174
CTYPE_PARAMS zp = 0.0;

kernels/quantized/test/op_embedding_test.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,38 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
373373
out),
374374
"");
375375
}
376+
377+
TEST(OpQuantizedEmbeddingTest, TestOutOfBoundsIndex) {
378+
et_pal_init();
379+
TensorFactory<ScalarType::Float> tf;
380+
TensorFactory<ScalarType::Long> tf_l;
381+
382+
int64_t quant_min = 0;
383+
int64_t quant_max = 255;
384+
385+
// Create a weight tensor with 3 rows
386+
TensorFactory<ScalarType::Byte> tfo;
387+
Tensor qweight =
388+
tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
389+
390+
// Create weight_scales with the same number of rows
391+
Tensor weight_scales = tf.make({3, 1}, {0.5, 1.0, 1.5});
392+
Tensor weight_zero_points = tf.make({3, 1}, {1, 5, 7});
393+
394+
// Create indices with an out-of-bounds index (3, which is >= weight.size(0))
395+
Tensor indices = tf_l.make({2}, {1, 3});
396+
397+
Tensor out = tf.zeros({2, 4});
398+
399+
// Expect death when accessing an out-of-bounds index
400+
ET_EXPECT_DEATH(
401+
quantized_embedding_byte_out(
402+
qweight,
403+
weight_scales,
404+
weight_zero_points,
405+
quant_min,
406+
quant_max,
407+
indices,
408+
out),
409+
"Index out of bounds for weight");
410+
}

0 commit comments

Comments
 (0)