Skip to content

Commit dbc6cde

Browse files
authored
Merge pull request #1 from ggml-org/0cc4m/vulkan-op-opt-step-sgd
Finish Vulkan OPT_STEP_SGD op implementation
2 parents 50e83ea + 2ec70c9 commit dbc6cde

File tree

5 files changed

+50
-29
lines changed

5 files changed

+50
-29
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ struct vk_device_struct {
507507
vk_pipeline pipeline_rwkv_wkv6_f32;
508508
vk_pipeline pipeline_rwkv_wkv7_f32;
509509
vk_pipeline pipeline_opt_step_adamw_f32;
510+
vk_pipeline pipeline_opt_step_sgd_f32;
510511
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
511512
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
512513
vk_pipeline pipeline_conv2d_dw_whcn_f32;
@@ -3085,6 +3086,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
30853086

30863087
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
30873088

3089+
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3090+
30883091
// conv2d
30893092
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
30903093
uint32_t conv2d_WG_SIZE = 256;
@@ -7120,7 +7123,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
71207123
return nullptr;
71217124
case GGML_OP_OPT_STEP_SGD:
71227125
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
7123-
// TODO
7126+
return ctx->device->pipeline_opt_step_sgd_f32;
71247127
}
71257128
return nullptr;
71267129
case GGML_OP_LEAKY_RELU:
@@ -7599,6 +7602,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
75997602
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
76007603
ggml_vk_sync_buffers(subctx);
76017604
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
7605+
} else if (op == GGML_OP_OPT_STEP_SGD) {
7606+
// OPT_STEP_SGD updates src0 in place; it does not need dst
7607+
ggml_vk_sync_buffers(subctx);
7608+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
76027609
} else if (use_src2) {
76037610
ggml_vk_sync_buffers(subctx);
76047611
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
@@ -7937,18 +7944,10 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
79377944
);
79387945
}
79397946

7940-
static void ggml_vk_op_f32_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
7941-
GGML_ASSERT(0 && "SGD vulkan unimplemented"); // TODO
7942-
}
7943-
7944-
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
7947+
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
79457948
const size_t n = ggml_nelements(dst->src[0]);
79467949

7947-
ggml_vk_op_f32_opt_step_sgd(
7948-
ctx, subctx, dst,
7949-
{ (uint32_t)n, 0, 0.0f, 0.0f },
7950-
dryrun
7951-
);
7950+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
79527951
}
79537952

79547953
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9489,6 +9488,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
94899488
case GGML_OP_LEAKY_RELU:
94909489
case GGML_OP_FLASH_ATTN_EXT:
94919490
case GGML_OP_OPT_STEP_ADAMW:
9491+
case GGML_OP_OPT_STEP_SGD:
94929492
break;
94939493
default:
94949494
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9553,6 +9553,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
95539553
case GGML_OP_CONV_2D:
95549554
case GGML_OP_CONV_2D_DW:
95559555
case GGML_OP_LEAKY_RELU:
9556+
case GGML_OP_OPT_STEP_SGD:
95569557
{
95579558
// These operations all go through ggml_vk_op_f32, so short-circuit and
95589559
// do the only thing needed for the dryrun.
@@ -9800,8 +9801,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
98009801
break;
98019802

98029803
case GGML_OP_OPT_STEP_SGD:
9803-
return false; // TODO
9804-
ggml_vk_opt_step_sgd(ctx, compute_ctx, node, dryrun);
9804+
ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
98059805

98069806
break;
98079807
default:
@@ -9905,10 +9905,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
99059905
case GGML_OP_REPEAT:
99069906
case GGML_OP_REPEAT_BACK:
99079907
case GGML_OP_OPT_STEP_ADAMW:
9908+
case GGML_OP_OPT_STEP_SGD:
99089909
buf = tensor->buffer;
99099910
break;
9910-
case GGML_OP_OPT_STEP_SGD:
9911-
return false;
99129911
case GGML_OP_UNARY:
99139912
switch (ggml_get_unary_op(tensor)) {
99149913
case GGML_UNARY_OP_SILU:
@@ -11036,6 +11035,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1103611035
case GGML_OP_SIN:
1103711036
case GGML_OP_COS:
1103811037
case GGML_OP_CLAMP:
11038+
case GGML_OP_LEAKY_RELU:
11039+
case GGML_OP_OPT_STEP_ADAMW:
11040+
case GGML_OP_OPT_STEP_SGD:
1103911041
return op->src[0]->type == GGML_TYPE_F32;
1104011042
case GGML_OP_UPSCALE:
1104111043
case GGML_OP_ACC:
@@ -11057,11 +11059,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1105711059
case GGML_OP_POOL_2D:
1105811060
case GGML_OP_RWKV_WKV6:
1105911061
case GGML_OP_RWKV_WKV7:
11060-
case GGML_OP_LEAKY_RELU:
11061-
case GGML_OP_OPT_STEP_ADAMW:
1106211062
return true;
11063-
case GGML_OP_OPT_STEP_SGD:
11064-
return false;
1106511063
case GGML_OP_CONV_TRANSPOSE_1D:
1106611064
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
1106711065
case GGML_OP_CONV_2D:
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) buffer X {A_TYPE data_x[];};
layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
layout (binding = 2) readonly buffer P {float data_params[2];};

// One SGD step per element: x <- x * keep - alpha * grad.
// data_params[0] is the learning rate (alpha); data_params[1] is the "keep"
// factor (presumably 1 - alpha*weight_decay — confirm against ggml SGD semantics).
void main() {
    // Flatten the 3D dispatch to a linear element index.
    // 262144 == 512 * 512 (local_size_x times the y-dimension stride).
    const uint idx = gl_GlobalInvocationID.z * 262144 +
                     gl_GlobalInvocationID.y * 512 +
                     gl_GlobalInvocationID.x;

    // Only threads inside the tensor's element count do any work.
    if (idx < p.KX) {
        const float lr   = data_params[0];
        const float keep = data_params[1];

        data_x[idx] = data_x[idx] * keep - lr * data_grad[idx];
    }
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,7 @@ void process_shaders() {
654654
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
655655

656656
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
657+
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
657658

658659
string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
659660
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});

ggml/src/ggml.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,8 +1006,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
10061006
"CROSS_ENTROPY_LOSS",
10071007
"CROSS_ENTROPY_LOSS_BACK",
10081008
"OPT_STEP_ADAMW",
1009-
"GLU",
10101009
"OPT_STEP_SGD",
1010+
1011+
"GLU",
10111012
};
10121013

10131014
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
@@ -1106,8 +1107,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
11061107
"cross_entropy_loss(x,y)",
11071108
"cross_entropy_loss_back(x,y)",
11081109
"adamw(x)",
1109-
"glu(x)",
11101110
"sgd(x)",
1111+
1112+
"glu(x)",
11111113
};
11121114

11131115
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

tests/test-backend-ops.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5110,7 +5110,7 @@ static const ggml_type other_types[] = {
51105110
};
51115111

51125112
// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
5113-
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sgd = true) {
5113+
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
51145114
std::vector<std::unique_ptr<test_case>> test_cases;
51155115
std::default_random_engine rng(0);
51165116

@@ -5912,8 +5912,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sg
59125912
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
59135913

59145914
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
5915-
if (test_sgd)
5916-
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 }));
5915+
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 }));
59175916

59185917
#if 0
59195918
// these tests are disabled to save execution time, but they can be handy for debugging
@@ -6051,10 +6050,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
60516050
}
60526051
};
60536052

6054-
char const* name = ggml_backend_name(backend);
6055-
bool const vulkan = strstr(name, "ulkan");
6056-
bool const sgd = !vulkan;
6057-
60586053
if (mode == MODE_TEST) {
60596054
auto test_cases = make_test_cases_eval();
60606055
filter_test_cases(test_cases, params_filter);
@@ -6080,7 +6075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
60806075
}
60816076

60826077
if (mode == MODE_GRAD) {
6083-
auto test_cases = make_test_cases_eval(sgd);
6078+
auto test_cases = make_test_cases_eval();
60846079
filter_test_cases(test_cases, params_filter);
60856080
size_t n_ok = 0;
60866081
for (auto & test : test_cases) {

0 commit comments

Comments
 (0)