@@ -554,6 +554,7 @@ struct vk_device_struct {
554
554
vk_pipeline pipeline_argmax_f32;
555
555
vk_pipeline pipeline_count_equal_i32;
556
556
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
557
+ vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
557
558
vk_pipeline pipeline_timestep_embedding_f32;
558
559
vk_pipeline pipeline_conv_transpose_1d_f32;
559
560
vk_pipeline pipeline_pool2d_f32;
@@ -982,6 +983,37 @@ struct vk_op_im2col_push_constants {
982
983
int32_t d0; int32_t d1;
983
984
};
984
985
986
+ struct vk_op_im2col_3d_push_constants {
987
+ uint32_t nb10;
988
+ uint32_t nb11;
989
+ uint32_t nb12;
990
+ uint32_t nb13;
991
+ uint32_t s0;
992
+ uint32_t s1;
993
+ uint32_t s2;
994
+ uint32_t p0;
995
+ uint32_t p1;
996
+ uint32_t p2;
997
+ uint32_t d0;
998
+ uint32_t d1;
999
+ uint32_t d2;
1000
+ uint32_t IW;
1001
+ uint32_t IH;
1002
+ uint32_t ID;
1003
+ uint32_t IC;
1004
+ uint32_t KW;
1005
+ uint32_t OH;
1006
+ uint32_t KD_KH_KW;
1007
+ uint32_t KH_KW;
1008
+ uint32_t IC_KD_KH_KW;
1009
+ uint32_t N_OD_OH;
1010
+ uint32_t OD_OH;
1011
+ uint32_t OD_OH_OW_IC_KD_KH_KW;
1012
+ uint32_t OH_OW_IC_KD_KH_KW;
1013
+ uint32_t OW_IC_KD_KH_KW;
1014
+ uint32_t misalign_offsets;
1015
+ };
1016
+
985
1017
struct vk_op_timestep_embedding_push_constants {
986
1018
uint32_t nb1;
987
1019
uint32_t dim;
@@ -3380,10 +3412,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
3380
3412
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
3381
3413
3382
3414
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3415
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3383
3416
if (device->float_controls_rte_fp16) {
3384
3417
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3418
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3385
3419
} else {
3386
3420
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3421
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3387
3422
}
3388
3423
3389
3424
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -7717,6 +7752,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
7717
7752
return ctx->device->pipeline_im2col_f32_f16;
7718
7753
}
7719
7754
return nullptr;
7755
+ case GGML_OP_IM2COL_3D:
7756
+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7757
+ return ctx->device->pipeline_im2col_3d_f32;
7758
+ }
7759
+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
7760
+ return ctx->device->pipeline_im2col_3d_f32_f16;
7761
+ }
7762
+ return nullptr;
7720
7763
case GGML_OP_TIMESTEP_EMBEDDING:
7721
7764
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7722
7765
return ctx->device->pipeline_timestep_embedding_f32;
@@ -7832,6 +7875,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
7832
7875
case GGML_OP_RMS_NORM:
7833
7876
case GGML_OP_CONV_2D_DW:
7834
7877
case GGML_OP_IM2COL:
7878
+ case GGML_OP_IM2COL_3D:
7835
7879
case GGML_OP_SET_ROWS:
7836
7880
case GGML_OP_SUM:
7837
7881
case GGML_OP_SUM_ROWS:
@@ -7890,6 +7934,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
7890
7934
GGML_UNUSED(src2);
7891
7935
}
7892
7936
7937
+ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
7938
+ const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
7939
+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
7940
+
7941
+ p.misalign_offsets = (a_offset << 16) | d_offset;
7942
+
7943
+ GGML_UNUSED(src0);
7944
+ GGML_UNUSED(src2);
7945
+ }
7946
+
7893
7947
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
7894
7948
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
7895
7949
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
@@ -8130,6 +8184,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
8130
8184
8131
8185
elements = { OW * KW * KH, OH, batch * IC };
8132
8186
} break;
8187
+ case GGML_OP_IM2COL_3D:
8188
+ {
8189
+ const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];
8190
+
8191
+ const uint32_t N = ne13 / IC;
8192
+
8193
+ const uint32_t KD = ne02;
8194
+ const uint32_t KH = ne01;
8195
+ const uint32_t KW = ne00;
8196
+
8197
+ const uint32_t OD = ned3 / N;
8198
+ const uint32_t OH = ned2;
8199
+ const uint32_t OW = ned1;
8200
+
8201
+ const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
8202
+ const uint32_t N_OD_OH = N*OD*OH;
8203
+
8204
+ elements = { IC_KD_KH_KW, OW, N_OD_OH };
8205
+ elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
8206
+ } break;
8133
8207
case GGML_OP_TIMESTEP_EMBEDDING:
8134
8208
{
8135
8209
const uint32_t dim = dst->op_params[0];
@@ -8286,7 +8360,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
8286
8360
}
8287
8361
8288
8362
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
8289
- } else if (op == GGML_OP_IM2COL) {
8363
+ } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D ) {
8290
8364
// im2col uses only src1 and dst buffers
8291
8365
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
8292
8366
} else if (op == GGML_OP_COUNT_EQUAL) {
@@ -9147,6 +9221,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
9147
9221
}, dryrun);
9148
9222
}
9149
9223
9224
+ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
9225
+ GGML_TENSOR_BINARY_OP_LOCALS
9226
+
9227
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
9228
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
9229
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
9230
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
9231
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
9232
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
9233
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
9234
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
9235
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
9236
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
9237
+
9238
+ const int64_t N = ne13 / IC;
9239
+ const int64_t ID = ne12;
9240
+ const int64_t IH = ne11;
9241
+ const int64_t IW = ne10;
9242
+
9243
+ const int64_t KD = ne02;
9244
+ const int64_t KH = ne01;
9245
+ const int64_t KW = ne00;
9246
+
9247
+ const int64_t OD = ne3 / N;
9248
+ const int64_t OH = ne2;
9249
+ const int64_t OW = ne1;
9250
+
9251
+ vk_op_im2col_3d_push_constants pc {};
9252
+
9253
+ pc.nb10 = nb10 / ggml_type_size(src1->type);
9254
+ pc.nb11 = nb11 / ggml_type_size(src1->type);
9255
+ pc.nb12 = nb12 / ggml_type_size(src1->type);
9256
+ pc.nb13 = nb13 / ggml_type_size(src1->type);
9257
+ pc.s0 = s0;
9258
+ pc.s1 = s1;
9259
+ pc.s2 = s2;
9260
+ pc.p0 = p0;
9261
+ pc.p1 = p1;
9262
+ pc.p2 = p2;
9263
+ pc.d0 = d0;
9264
+ pc.d1 = d1;
9265
+ pc.d2 = d2;
9266
+ pc.IW = IW;
9267
+ pc.IH = IH;
9268
+ pc.ID = ID;
9269
+ pc.IC = IC;
9270
+ pc.KW = KW;
9271
+ pc.OH = OH;
9272
+ pc.KD_KH_KW = KD*KH*KW;
9273
+ pc.KH_KW = KH*KW;
9274
+ pc.IC_KD_KH_KW = IC*KD*KH*KW;
9275
+ pc.N_OD_OH = N*OD*OH;
9276
+ pc.OD_OH = OD*OH;
9277
+ pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
9278
+ pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
9279
+ pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
9280
+
9281
+ ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
9282
+ }
9283
+
9150
9284
static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
9151
9285
const uint32_t dim = dst->op_params[0];
9152
9286
const uint32_t max_period = dst->op_params[1];
@@ -10352,6 +10486,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
10352
10486
case GGML_OP_ARGMAX:
10353
10487
case GGML_OP_COUNT_EQUAL:
10354
10488
case GGML_OP_IM2COL:
10489
+ case GGML_OP_IM2COL_3D:
10355
10490
case GGML_OP_TIMESTEP_EMBEDDING:
10356
10491
case GGML_OP_CONV_TRANSPOSE_1D:
10357
10492
case GGML_OP_POOL_2D:
@@ -10422,6 +10557,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
10422
10557
case GGML_OP_ARGMAX:
10423
10558
case GGML_OP_COUNT_EQUAL:
10424
10559
case GGML_OP_IM2COL:
10560
+ case GGML_OP_IM2COL_3D:
10425
10561
case GGML_OP_TIMESTEP_EMBEDDING:
10426
10562
case GGML_OP_CONV_TRANSPOSE_1D:
10427
10563
case GGML_OP_POOL_2D:
@@ -10717,6 +10853,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
10717
10853
case GGML_OP_IM2COL:
10718
10854
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
10719
10855
10856
+ break;
10857
+ case GGML_OP_IM2COL_3D:
10858
+ ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun);
10859
+
10720
10860
break;
10721
10861
case GGML_OP_TIMESTEP_EMBEDDING:
10722
10862
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
@@ -10868,6 +11008,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
10868
11008
case GGML_OP_ARGMAX:
10869
11009
case GGML_OP_COUNT_EQUAL:
10870
11010
case GGML_OP_IM2COL:
11011
+ case GGML_OP_IM2COL_3D:
10871
11012
case GGML_OP_TIMESTEP_EMBEDDING:
10872
11013
case GGML_OP_CONV_TRANSPOSE_1D:
10873
11014
case GGML_OP_POOL_2D:
@@ -12150,6 +12291,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
12150
12291
case GGML_OP_ARGMAX:
12151
12292
case GGML_OP_COUNT_EQUAL:
12152
12293
case GGML_OP_IM2COL:
12294
+ case GGML_OP_IM2COL_3D:
12153
12295
case GGML_OP_TIMESTEP_EMBEDDING:
12154
12296
case GGML_OP_CONV_2D_DW:
12155
12297
case GGML_OP_POOL_2D:
@@ -12725,6 +12867,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
12725
12867
12726
12868
const bool is_2D = tensor->op_params[6] == 1;
12727
12869
tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
12870
+ } else if (tensor->op == GGML_OP_IM2COL_3D) {
12871
+ const int32_t s0 = tensor->op_params[0];
12872
+ const int32_t s1 = tensor->op_params[1];
12873
+ const int32_t s1 = tensor->op_params[2];
12874
+ const int32_t p0 = tensor->op_params[3];
12875
+ const int32_t p1 = tensor->op_params[4];
12876
+ const int32_t p1 = tensor->op_params[5];
12877
+ const int32_t d0 = tensor->op_params[6];
12878
+ const int32_t d1 = tensor->op_params[7];
12879
+ const int32_t d1 = tensor->op_params[8];
12880
+ const int32_t IC = tensor->op_params[9];
12881
+
12882
+ tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
12728
12883
} else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
12729
12884
const int32_t dim = tensor->op_params[0];
12730
12885
const int32_t max_period = tensor->op_params[1];
0 commit comments