@@ -251,6 +251,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+    vk_pipeline pipeline_soft_max_back_f32;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -2188,6 +2189,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -5330,6 +5332,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SOFT_MAX_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_soft_max_back_f32;
+        }
+        return nullptr;
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
         {
@@ -5643,6 +5650,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
         {
             const uint32_t nr = ggml_nrows(src0);
@@ -6203,6 +6211,11 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }
 
+static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+}
+
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
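Note: the new ggml_vk_soft_max_back helper reuses the generic vk_op_push_constants block rather than vk_op_soft_max_push_constants. A rough sketch of the layout it fills above (the struct itself is not shown in this diff, so the field names are assumptions from memory and may differ from the actual source):

struct vk_op_push_constants {
    uint32_t KX;      // src0->ne[0], elements per row
    uint32_t KY;      // src0->ne[1], number of rows
    float    param1;  // op_params[0], the soft_max scale
    float    param2;  // op_params[1], the soft_max max_bias
};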
@@ -7145,6 +7158,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_MUL_MAT:
@@ -7201,6 +7215,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_ARGSORT:
@@ -7324,6 +7339,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
         ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_SOFT_MAX_BACK:
+        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_ROPE:
         ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
@@ -7445,6 +7464,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_RESHAPE:
@@ -8376,6 +8396,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_PAD:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
@@ -8901,6 +8922,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
         }
+    } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+        tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
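For reference, the row-wise math the new soft_max_back_f32 shader has to reproduce is the standard softmax Jacobian-vector product, dx = scale * y * (dy - dot(y, dy)). A minimal CPU-side sketch, assuming the ggml convention that src0 carries the incoming gradient dy and src1 the forward softmax output y, with op_params[0] as the scale; the helper name is illustrative and this is not the shader source from this commit:

#include <cstddef>

// Per-row reference for soft_max_back: dx = scale * y * (dy - dot(y, dy)).
static void soft_max_back_row_ref(const float * dy, const float * y, float * dx,
                                  size_t n, float scale) {
    float dot = 0.0f;                          // dot(y, dy) over the row
    for (size_t i = 0; i < n; ++i) {
        dot += y[i] * dy[i];
    }
    for (size_t i = 0; i < n; ++i) {
        dx[i] = scale * y[i] * (dy[i] - dot);
    }
}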