@@ -232,7 +232,7 @@ struct vk_device_struct {
232232 vk_pipeline pipeline_cos_f32;
233233 vk_pipeline pipeline_clamp_f32;
234234 vk_pipeline pipeline_pad_f32;
235- vk_pipeline pipeline_repeat_f32;
235+ vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32 ;
236236 vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
237237 vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
238238 vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -2127,6 +2127,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21272127 ggml_vk_create_pipeline (device, device->pipeline_pad_f32 , " pad_f32" , pad_f32_len, pad_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
21282128
21292129 ggml_vk_create_pipeline (device, device->pipeline_repeat_f32 , " repeat_f32" , repeat_f32_len, repeat_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
2130+ ggml_vk_create_pipeline (device, device->pipeline_repeat_back_f32 , " repeat_back_f32" , repeat_back_f32_len, repeat_back_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
21302131
21312132 ggml_vk_create_pipeline (device, device->pipeline_gelu_f32 , " gelu_f32" , gelu_f32_len, gelu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21322133 ggml_vk_create_pipeline (device, device->pipeline_gelu_quick_f32 , " gelu_quick_f32" , gelu_quick_f32_len, gelu_quick_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
@@ -5201,6 +5202,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52015202 return ctx->device ->pipeline_repeat_f32 ;
52025203 }
52035204 return nullptr ;
5205+ case GGML_OP_REPEAT_BACK:
5206+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5207+ return ctx->device ->pipeline_repeat_back_f32 ;
5208+ }
5209+ return nullptr ;
52045210 case GGML_OP_CPY:
52055211 case GGML_OP_CONT:
52065212 case GGML_OP_DUP:
@@ -5365,6 +5371,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
53655371 case GGML_OP_CLAMP:
53665372 case GGML_OP_PAD:
53675373 case GGML_OP_REPEAT:
5374+ case GGML_OP_REPEAT_BACK:
53685375 return true ;
53695376 default :
53705377 return false ;
@@ -5649,6 +5656,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
56495656 case GGML_OP_CLAMP:
56505657 case GGML_OP_PAD:
56515658 case GGML_OP_REPEAT:
5659+ case GGML_OP_REPEAT_BACK:
56525660 case GGML_OP_CPY:
56535661 case GGML_OP_CONCAT:
56545662 case GGML_OP_UPSCALE:
@@ -6182,6 +6190,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
61826190 }, dryrun);
61836191}
61846192
6193+ static void ggml_vk_repeat_back (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
6194+ const uint32_t src0_type_size = ggml_type_size (src0->type );
6195+ const uint32_t dst_type_size = ggml_type_size (dst->type );
6196+
6197+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_REPEAT_BACK, {
6198+ (uint32_t )ggml_nelements (dst),
6199+ (uint32_t )src0->ne [0 ], (uint32_t )src0->ne [1 ], (uint32_t )src0->ne [2 ], (uint32_t )src0->ne [3 ], (uint32_t )src0->nb [0 ] / src0_type_size, (uint32_t )src0->nb [1 ] / src0_type_size, (uint32_t )src0->nb [2 ] / src0_type_size, (uint32_t )src0->nb [3 ] / src0_type_size,
6200+ (uint32_t ) dst->ne [0 ], (uint32_t ) dst->ne [1 ], (uint32_t ) dst->ne [2 ], (uint32_t ) dst->ne [3 ], (uint32_t ) dst->nb [0 ] / dst_type_size, (uint32_t ) dst->nb [1 ] / dst_type_size, (uint32_t ) dst->nb [2 ] / dst_type_size, (uint32_t ) dst->nb [3 ] / dst_type_size,
6201+ 0 ,
6202+ 0 .0f , 0 .0f ,
6203+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
6204+ }, dryrun);
6205+ }
6206+
61856207static void ggml_vk_cpy (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
61866208 const uint32_t src0_type_size = ggml_type_size (src0->type );
61876209 const uint32_t dst_type_size = ggml_type_size (dst->type );
@@ -7177,6 +7199,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71777199 }
71787200 break ;
71797201 case GGML_OP_REPEAT:
7202+ case GGML_OP_REPEAT_BACK:
71807203 case GGML_OP_GET_ROWS:
71817204 case GGML_OP_ADD:
71827205 case GGML_OP_ACC:
@@ -7234,6 +7257,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72347257 } else {
72357258 switch (node->op ) {
72367259 case GGML_OP_REPEAT:
7260+ case GGML_OP_REPEAT_BACK:
72377261 case GGML_OP_ACC:
72387262 case GGML_OP_GET_ROWS:
72397263 case GGML_OP_ADD:
@@ -7283,6 +7307,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72837307 case GGML_OP_REPEAT:
72847308 ggml_vk_repeat (ctx, compute_ctx, src0, node, dryrun);
72857309
7310+ break ;
7311+ case GGML_OP_REPEAT_BACK:
7312+ ggml_vk_repeat_back (ctx, compute_ctx, src0, node, dryrun);
7313+
72867314 break ;
72877315 case GGML_OP_ACC:
72887316 ggml_vk_acc (ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7528,6 +7556,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
75287556 case GGML_OP_RWKV_WKV6:
75297557 case GGML_OP_LEAKY_RELU:
75307558 case GGML_OP_REPEAT:
7559+ case GGML_OP_REPEAT_BACK:
75317560 case GGML_OP_OPT_STEP_ADAMW:
75327561 buf = tensor->buffer ;
75337562
@@ -8420,6 +8449,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
84208449 } break ;
84218450 case GGML_OP_REPEAT:
84228451 return ggml_type_size (op->type ) == sizeof (float ) && ggml_type_size (op->src [0 ]->type ) == sizeof (float );
8452+ case GGML_OP_REPEAT_BACK:
8453+ return op->type == GGML_TYPE_F32 && op->src [0 ]->type == GGML_TYPE_F32;
84238454 case GGML_OP_ROPE:
84248455 {
84258456 const int mode = ((const int32_t *) op->op_params )[2 ];
@@ -8830,6 +8861,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
88308861 tensor_clone = ggml_pad (ggml_ctx, src_clone[0 ], tensor->ne [0 ] - src_clone[0 ]->ne [0 ], tensor->ne [1 ] - src_clone[0 ]->ne [1 ], tensor->ne [2 ] - src_clone[0 ]->ne [2 ], tensor->ne [3 ] - src_clone[0 ]->ne [3 ]);
88318862 } else if (tensor->op == GGML_OP_REPEAT) {
88328863 tensor_clone = ggml_repeat (ggml_ctx, src_clone[0 ], tensor);
8864+ } else if (tensor->op == GGML_OP_REPEAT_BACK) {
8865+ tensor_clone = ggml_repeat_back (ggml_ctx, src_clone[0 ], tensor);
88338866 } else if (tensor->op == GGML_OP_ADD) {
88348867 tensor_clone = ggml_add (ggml_ctx, src_clone[0 ], src_clone[1 ]);
88358868 } else if (tensor->op == GGML_OP_ACC) {
0 commit comments