@@ -2135,6 +2135,43 @@ struct test_conv_transpose_1d : public test_case {
     }
 };

+struct test_conv_transpose_1d_gemm : public test_case {
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+
+    const int s0; // stride
+    const int p0; // padding
+    const int d0; // dilation
+
+    ggml_type input_type;
+    ggml_type kernel_type;
+
+    std::string vars() override {
+        return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
+    }
+
+    test_conv_transpose_1d_gemm(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
+                                std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
+                                int s0 = 1, int p0 = 0, int d0 = 1,
+                                ggml_type input_type = GGML_TYPE_F32,
+                                ggml_type kernel_type = GGML_TYPE_F16)
+        : ne_input(ne_input)
+        , ne_kernel(ne_kernel)
+        , s0(s0)
+        , p0(p0)
+        , d0(d0)
+        , input_type(input_type)
+        , kernel_type(kernel_type)
+    {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input  = ggml_new_tensor(ctx, input_type,  4, ne_input.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, kernel_type, 4, ne_kernel.data());
+        ggml_tensor * out = ggml_conv_transpose_1d_gemm(ctx, kernel, input, s0, p0, d0);
+        return out;
+    }
+};
+
 // GGML_OP_IM2COL
 struct test_im2col : public test_case {
     const ggml_type type_input;
@@ -3200,6 +3237,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

+    test_cases.emplace_back(new test_conv_transpose_1d_gemm());
+    for (int64_t s0 = 1; s0 < 4; ++s0) {
+        for (int64_t p0 = 0; p0 < 2; ++p0) {
+            for (int64_t d0 = 1; d0 < 4; ++d0) {
+                test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {2,3,2,1}, s0, p0, d0));
+                test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,2,2,1}, s0, p0, d0));
+                test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,1,2,1}, s0, p0, d0));
+                test_cases.emplace_back(new test_conv_transpose_1d_gemm({2,1,1,1}, {3,1,1,1}, s0, p0, d0));
+                test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {2,3,2,1},
+                                                                        s0, p0, d0, GGML_TYPE_F16));
+                test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,2,2,1},
+                                                                        s0, p0, d0, GGML_TYPE_F16));
+                test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,1,2,1},
+                                                                        s0, p0, d0, GGML_TYPE_F16));
+                test_cases.emplace_back(new test_conv_transpose_1d_gemm({2,1,1,1}, {3,1,1,1},
+                                                                        s0, p0, d0, GGML_TYPE_F16));
+            }
+        }
+    }

     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1}));
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}));
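For reviewers who want to exercise the new operator outside the test harness, here is a minimal standalone sketch of what one of the small cases above builds. It is illustrative only and not part of the diff: ggml_conv_transpose_1d_gemm is the entry point added by this PR (the kernel-first argument order and the s0/p0/d0 parameters are taken from build_graph() above), while the context and graph boilerplate is the stock ggml C API.

// Minimal sketch, not part of the diff. Mirrors
// test_conv_transpose_1d_gemm({3,2,1,1}, {3,1,2,1}, /*s0=*/1, /*p0=*/0, /*d0=*/1).
// ggml_conv_transpose_1d_gemm is assumed to be the function introduced by this PR.
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // [input_width, input_height, input_channels, 1] and
    // [kernel_width, kernel_height, input_channels, 1], as in the test case above
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 2, 1, 1);
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 1, 2, 1);

    // stride 1, padding 0, dilation 1
    struct ggml_tensor * out = ggml_conv_transpose_1d_gemm(ctx, kernel, input, 1, 0, 1);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    // fill input and kernel with data here, then compute on the CPU backend
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    ggml_free(ctx);
    return 0;
}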