Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Legend:
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
Expand Down
8,140 changes: 7 additions & 8,133 deletions docs/ops/SYCL.csv

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions docs/ops/_ALL.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"backend_name","op_name","op_params","test_mode","supported","error_message","backend_reg_name"
"SYCL0","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","1","yes","SYCL"
"SYCL0","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","1","yes","SYCL"
"SYCL0","COUNT_EQUAL","type=f32,ne=[1024,1,1,1]","support","1","yes","SYCL"
"SYCL0","COUNT_EQUAL","type=f32,ne=[64,64,1,1]","support","1","yes","SYCL"
"SYCL0","COUNT_EQUAL","type=f16,ne=[256,32,1,1]","support","1","yes","SYCL"
"SYCL0","COUNT_EQUAL","type=i32,ne=[512,16,1,1]","support","1","yes","SYCL"
"SYCL0","COUNT_EQUAL","type=i16,ne=[512,16,1,1]","support","1","yes","SYCL"
11 changes: 11 additions & 0 deletions ggml/src/ggml-sycl/binbcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,11 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
}

inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
}


inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
Expand All @@ -327,6 +332,12 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_sub(ctx, dst);
}

void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_count_equal(ctx, dst);
}


void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_mul(ctx, dst);
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-sycl/binbcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
return a - b;
}

static __dpct_inline__ float op_count_equal(const float a, const float b) {
return (a == b) ? 1.0f : 0.0f;
}

void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);


static __dpct_inline__ float op_mul(const float a, const float b) {
return a * b;
}
Expand Down
8 changes: 6 additions & 2 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3570,7 +3570,11 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_SUB:
ggml_sycl_sub(ctx, dst);
break;
case GGML_OP_ACC:

case GGML_OP_COUNT_EQUAL:
ggml_sycl_count_equal(ctx, dst);
break;
case GGML_OP_ACC:
ggml_sycl_acc(ctx, dst);
break;
case GGML_OP_MUL:
Expand Down Expand Up @@ -4063,7 +4067,6 @@ static ggml_backend_i ggml_backend_sycl_interface = {
/* .graph_compute = */ ggml_backend_sycl_graph_compute,
/* .event_record = */ ggml_backend_sycl_event_record,
/* .event_wait = */ ggml_backend_sycl_event_wait,
/* .optimize_graph = */ NULL,
};

static ggml_guid_t ggml_backend_sycl_guid() {
Expand Down Expand Up @@ -4349,6 +4352,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_REPEAT:
Expand Down
31 changes: 31 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2237,6 +2237,31 @@ struct test_count_equal : public test_case {
};

// GGML_OP_REPEAT

/* COUNT_EQUAL – typed test (no argmax), to cover F32/F16/I32/I16 */
struct test_count_equal_typed : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

test_count_equal_typed(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {128, 64, 1, 1})
: type(type), ne(ne) {}

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");
ggml_tensor * out = ggml_count_equal(ctx, a, b);
ggml_set_name(out, "out");
return out;
}
};

struct test_repeat : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
Expand Down Expand Up @@ -5940,6 +5965,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
// COUNT_EQUAL – typed tests by dtype
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_F32, {1024, 1, 1, 1}));
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_F32, { 64, 64, 1, 1}));
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_F16, { 256, 32, 1, 1}));
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_I32, { 512, 16, 1, 1}));
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_I16, { 512, 16, 1, 1}));

test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 513, 1, 1}));
Expand Down