Skip to content

Commit 1b4c19b

Browse files
committed
Add tests for conv_transpose_1d_gemm
Signed-off-by: Salvatore Mesoraca <[email protected]>
1 parent 3de2490 commit 1b4c19b

File tree

4 files changed

+1003
-0
lines changed

4 files changed

+1003
-0
lines changed

tests/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,12 @@ add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
350350
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
351351
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
352352

353+
# test-conv-transpose-1d-gemm
354+
355+
set(TEST_TARGET test-conv-transpose-1d-gemm)
356+
add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
357+
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
358+
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
353359

354360
#
355361
# test-dup

tests/test-backend-ops.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2135,6 +2135,43 @@ struct test_conv_transpose_1d : public test_case {
21352135
}
21362136
};
21372137

2138+
struct test_conv_transpose_1d_gemm : public test_case {
2139+
const std::array<int64_t, 4> ne_input;
2140+
const std::array<int64_t, 4> ne_kernel;
2141+
2142+
const int s0; // stride
2143+
const int p0; // padding
2144+
const int d0; // dilation
2145+
2146+
ggml_type input_type;
2147+
ggml_type kernel_type;
2148+
2149+
std::string vars() override {
2150+
return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
2151+
}
2152+
2153+
test_conv_transpose_1d_gemm(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
2154+
std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
2155+
int s0 = 1, int p0 = 0, int d0 = 1,
2156+
ggml_type input_type = GGML_TYPE_F32,
2157+
ggml_type kernel_type = GGML_TYPE_F16)
2158+
: ne_input(ne_input)
2159+
, ne_kernel(ne_kernel)
2160+
, s0(s0)
2161+
, p0(p0)
2162+
, d0(d0)
2163+
, input_type(input_type)
2164+
, kernel_type(kernel_type)
2165+
{}
2166+
2167+
ggml_tensor * build_graph(ggml_context * ctx) override {
2168+
ggml_tensor * input = ggml_new_tensor(ctx, input_type, 4, ne_input.data());
2169+
ggml_tensor * kernel = ggml_new_tensor(ctx, kernel_type, 4, ne_kernel.data());
2170+
ggml_tensor * out = ggml_conv_transpose_1d_gemm(ctx, kernel, input, s0, p0, d0);
2171+
return out;
2172+
}
2173+
};
2174+
21382175
// GGML_OP_IM2COL
21392176
struct test_im2col : public test_case {
21402177
const ggml_type type_input;
@@ -3200,6 +3237,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
32003237
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
32013238
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
32023239

3240+
test_cases.emplace_back(new test_conv_transpose_1d_gemm());
3241+
for (int64_t s0 = 1; s0 < 4; ++s0) {
3242+
for (int64_t p0 = 0; p0 < 2; ++p0) {
3243+
for (int64_t d0 = 1; d0 < 4; ++d0) {
3244+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {2,3,2,1}, s0, p0, d0));
3245+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,2,2,1}, s0, p0, d0));
3246+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,1,2,1}, s0, p0, d0));
3247+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({2,1,1,1}, {3,1,1,1}, s0, p0, d0));
3248+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {2,3,2,1},
3249+
s0, p0, d0, GGML_TYPE_F16));
3250+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,2,2,1},
3251+
s0, p0, d0, GGML_TYPE_F16));
3252+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,1,2,1},
3253+
s0, p0, d0, GGML_TYPE_F16));
3254+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({2,1,1,1}, {3,1,1,1},
3255+
s0, p0, d0, GGML_TYPE_F16));
3256+
}
3257+
}
3258+
}
32033259

32043260
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1}));
32053261
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}));

0 commit comments

Comments
 (0)