@@ -3730,6 +3730,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
37303730 }
37313731}
37323732
// Fills `size` bytes of `dst`, starting at byte `offset`, with the 32-bit value `c`
// by calling fillBuffer on the command buffer held by `ctx` (ctx->s->buffer).
//
//   ctx    - context whose command buffer (ctx->s->buffer) receives the fill
//   dst    - destination buffer; dst->buffer is the handle passed to fillBuffer
//   offset - start of the filled range inside dst
//   c      - 32-bit fill value (not a single byte, unlike C memset)
//   size   - length of the filled range
//
// NOTE(review): the "_async" suffix suggests the fill is only recorded here and
// runs when the context is submitted - confirm against how ctx->s->buffer is used
// elsewhere. Nothing in this function waits or synchronizes; the caller visible in
// ggml_vk_op_f32 wraps the call in ggml_vk_sync_buffers() before and after.
3733+ static void ggml_vk_buffer_memset_async (vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
3734+ VK_LOG_DEBUG (" ggml_vk_buffer_memset_async(" << offset << " , " << c << " , " << size << " )" );
3735+
// NOTE(review): presumably this is vk::CommandBuffer::fillBuffer (vkCmdFillBuffer),
// which requires offset and size to be multiples of 4 - no check is done here, so
// verify that every caller passes 4-byte-aligned values.
3736+ ctx->s ->buffer .fillBuffer (dst->buffer , offset, size, c);
3737+ }
3738+
37333739static void ggml_vk_buffer_memset (vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
37343740 VK_LOG_DEBUG (" ggml_vk_buffer_memset(" << offset << " , " << c << " , " << size << " )" );
37353741
@@ -5717,6 +5723,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
57175723 // im2col uses only src1 and dst buffers
57185724 ggml_vk_sync_buffers (subctx);
57195725 ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof (PC), &pc, elements);
5726+ } else if (op == GGML_OP_COUNT_EQUAL) {
5727+ ggml_vk_sync_buffers (subctx);
5728+ // count_equal assumes that destination buffer is initialized with zeroes
5729+ ggml_vk_buffer_memset_async (subctx, d_D, d_buf_offset, 0 , d_sz);
5730+ ggml_vk_sync_buffers (subctx);
5731+ ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof (PC), &pc, elements);
57205732 } else if (use_src2) {
57215733 ggml_vk_sync_buffers (subctx);
57225734 ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof (PC), &pc, elements);
@@ -6331,7 +6343,6 @@ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, co
63316343}
63326344
// Runs GGML_OP_COUNT_EQUAL by forwarding to ggml_vk_op_f32 with src0 and src1 as
// inputs, no third source (nullptr) and dst as the output tensor.
//
//   ctx, subctx - backend context and command context, passed through unchanged
//   src0, src1  - the two input tensors
//   dst         - output tensor
//   dryrun      - passed through to ggml_vk_op_f32 unchanged (defaults to false)
63336345static void ggml_vk_count_equal (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
// The '-' line below is removed by this change: it zeroed dst through
// ggml_backend_tensor_memset before the call, unconditionally (also when dryrun is
// true). The zeroing is now done in the GGML_OP_COUNT_EQUAL branch of
// ggml_vk_op_f32 via ggml_vk_buffer_memset_async, right before the dispatch.
6334- ggml_backend_tensor_memset (dst, 0 , 0 , ggml_nbytes (dst));
// Push constants: the element count of src0 narrowed to uint32_t, then three zero
// fields. NOTE(review): the cast truncates counts that do not fit in 32 bits -
// confirm such tensors cannot reach this path.
63356346 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_COUNT_EQUAL, { (uint32_t )ggml_nelements (src0), 0 , 0 .0f , 0 .0f }, dryrun);
63366347}
63376348
0 commit comments