Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2031,11 +2031,33 @@ extern "C" {
int dilation0,
int dilation1);

GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
GGML_API struct ggml_tensor * ggml_conv_2d_circular(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int stride);
int s0,
int s1,
int p0,
int p1,
int d0,
int d1);

GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct_circular(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int stride0,
int stride1,
int pad0,
int pad1,
int dilation0,
int dilation1);

GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int stride);

GGML_API struct ggml_tensor * ggml_conv_2d_direct(
struct ggml_context * ctx,
Expand All @@ -2048,6 +2070,17 @@ extern "C" {
int d0, // dilation dimension 0
int d1); // dilation dimension 1

GGML_API struct ggml_tensor * ggml_conv_2d_direct_circular(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1);

GGML_API struct ggml_tensor * ggml_conv_3d_direct(
struct ggml_context * ctx,
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
Expand Down Expand Up @@ -2168,6 +2201,19 @@ extern "C" {
int rp3
);

GGML_API struct ggml_tensor * ggml_pad_circular(
struct ggml_context * ctx,
struct ggml_tensor * a,
int lp0,
int rp0,
int lp1,
int rp1,
int lp2,
int rp2,
int lp3,
int rp3
);

// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
struct ggml_context * ctx,
Expand Down
112 changes: 84 additions & 28 deletions src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6660,6 +6660,12 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
ggml_compute_forward_mul_mat(params, &dst);
}

// ggml_wrap_coord

// Wrap a coordinate into [0, size) for circular ("torus") padding.
// The double modulo handles inputs more negative than -size (reachable with
// large dilations/paddings); a single `(coord + size) % size` would still
// return a negative value for coord < -size because C's % truncates toward zero.
static inline int64_t ggml_wrap_coord(int64_t coord, int64_t size) {
    return ((coord % size) + size) % size;
}

// ggml_compute_forward_conv_2d

static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
Expand All @@ -6680,6 +6686,7 @@ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params
const int32_t pad_y = dst->op_params[3];
const int32_t dilation_x = dst->op_params[4];
const int32_t dilation_y = dst->op_params[5];
const bool circular = ggml_get_op_params_i32(dst, 6) != 0;

const int64_t c_in = src->ne[2];
const int64_t c_out = kernel->ne[3];
Expand Down Expand Up @@ -6736,13 +6743,19 @@ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params

int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;

float src_val = 0.0f;
if (circular) {
    // circular padding: wrap out-of-range taps around the image borders
    const int64_t sy_wrapped = ggml_wrap_coord(sy, src_h);
    const int64_t sx_wrapped = ggml_wrap_coord(sx, src_w);
    const float * src_ptr = (const float *)((const char *)src_base + sx_wrapped * src->nb[0] + sy_wrapped * src->nb[1] + ic * src->nb[2]);
    src_val = *src_ptr;
} else if (sy >= 0 && sy < src_h && sx >= 0 && sx < src_w) {
    const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
    src_val = *src_ptr;
}
// Out-of-bounds taps under zero padding must contribute 0.0f, which the
// initializer already provides; writing 1.0f here (as the previous else
// branch did) would corrupt the im2col buffer for every padded position.

char * element_ptr = dst_row + dst_idx * traits->type_size;
if (kernel_type == GGML_TYPE_F32) {
Expand Down Expand Up @@ -7052,6 +7065,7 @@ struct ggml_conv_2d_dw_params {
int pad_y;
int dilation_x;
int dilation_y;
int circular;
};

static void ggml_compute_forward_conv_2d_dw_cwhn(
Expand All @@ -7063,6 +7077,7 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(

const int64_t c = p.channels;
const float * knl_data = (const float *)kernel->data;
const bool circular = p.circular != 0;

const int64_t rows_total = p.dst_h * p.batch;
const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
Expand Down Expand Up @@ -7090,13 +7105,17 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
const int64_t src_y = src_y_base + knl_y * p.dilation_y;
if (src_y < 0 || src_y >= p.src_h) {
int64_t src_y = src_y_base + knl_y * p.dilation_y;
if (circular) {
src_y = ggml_wrap_coord(src_y, p.src_h);
} else if (src_y < 0 || src_y >= p.src_h) {
continue;
}
for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
const int64_t src_x = src_x_base + knl_x * p.dilation_x;
if (src_x < 0 || src_x >= p.src_w) {
int64_t src_x = src_x_base + knl_x * p.dilation_x;
if (circular) {
src_x = ggml_wrap_coord(src_x, p.src_w);
} else if (src_x < 0 || src_x >= p.src_w) {
continue;
}
GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
Expand All @@ -7111,13 +7130,17 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
float sum = 0.0f;
for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
const int64_t src_y = src_y_base + knl_y * p.dilation_y;
if (src_y < 0 || src_y >= p.src_h) {
int64_t src_y = src_y_base + knl_y * p.dilation_y;
if (circular) {
src_y = ggml_wrap_coord(src_y, p.src_h);
} else if (src_y < 0 || src_y >= p.src_h) {
continue;
}
for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
const int64_t src_x = src_x_base + knl_x * p.dilation_x;
if (src_x < 0 || src_x >= p.src_w) {
int64_t src_x = src_x_base + knl_x * p.dilation_x;
if (circular) {
src_x = ggml_wrap_coord(src_x, p.src_w);
} else if (src_x < 0 || src_x >= p.src_w) {
continue;
}
sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
Expand All @@ -7138,6 +7161,7 @@ static void ggml_compute_forward_conv_2d_dw_whcn(
const ggml_conv_2d_dw_params & p) {

const int64_t n = p.channels * p.batch;
const bool circular = p.circular != 0;
const int64_t per_thread = (n + params->nth - 1) / params->nth;
const int64_t start = params->ith * per_thread;
const int64_t end = MIN(start + per_thread, n);
Expand All @@ -7152,13 +7176,17 @@ static void ggml_compute_forward_conv_2d_dw_whcn(

float sum = 0.0f;
for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y < 0 || src_y >= p.src_h) {
int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (circular) {
src_y = ggml_wrap_coord(src_y, p.src_h);
} else if (src_y < 0 || src_y >= p.src_h) {
continue;
}
for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x < 0 || src_x >= p.src_w) {
int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (circular) {
src_x = ggml_wrap_coord(src_x, p.src_w);
} else if (src_x < 0 || src_x >= p.src_w) {
continue;
}
sum += knl_data[knl_y * p.knl_w + knl_x]
Expand Down Expand Up @@ -7192,6 +7220,7 @@ void ggml_compute_forward_conv_2d_dw(
p.pad_y = dst->op_params[3];
p.dilation_x = dst->op_params[4];
p.dilation_y = dst->op_params[5];
p.circular = ggml_get_op_params_i32(dst, 6);

GGML_ASSERT(kernel->ne[3] == p.channels);
GGML_ASSERT(dst->ne[3] == p.batch);
Expand Down Expand Up @@ -7612,24 +7641,51 @@ static void ggml_compute_forward_pad_f32(
const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
const int32_t mode = ggml_get_op_params_i32(dst, 8);
const bool circular = mode == GGML_PAD_MODE_CIRCULAR;


// TODO: optimize

for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) \
&& (i1 >= lp1 && i1 < ne1 - rp1) \
&& (i2 >= lp2 && i2 < ne2 - rp2) \
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
if (!circular) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) \
&& (i1 >= lp1 && i1 < ne1 - rp1) \
&& (i2 >= lp2 && i2 < ne2 - rp2) \
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
}
}
}
}
}
} else {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
const int64_t src_i0 = ggml_wrap_coord(i0 - lp0, ne00);
const int64_t src_i1 = ggml_wrap_coord(i1 - lp1, ne01);
const int64_t src_i2 = ggml_wrap_coord(i2 - lp2, ne02);
const int64_t src_i3 = ggml_wrap_coord(i3 - lp3, ne03);

const int64_t src_idx =
src_i3*nb03 +
src_i2*nb02 +
src_i1*nb01 +
src_i0*nb00;

const float * src_ptr = (const float *)((char *) src0->data + src_idx);
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
}
}
}
Expand Down
48 changes: 36 additions & 12 deletions src/ggml-cuda/conv2d-dw.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,35 @@ struct conv_params {
int padding_x, padding_y;
int dilation_x, dilation_y;
int channels, batches;
int circular;
};

struct kernel_bounds {
int y_min, y_max;
int x_min, x_max;
};

// wrap a coordinate into [0, size) for circular padding; the double modulo
// also maps inputs below -size into range, which a single
// `(coord + size) % size` would leave negative (C++ % truncates toward zero)
__device__ __forceinline__ int wrap_coord(int coord, int size) {
    return ((coord % size) + size) % size;
}

// Compute the range of kernel taps [y_min, y_max) x [x_min, x_max) to visit
// for the output pixel (out_x, out_y).
// (The pasted hunk duplicated the pre-change bounds computation before this
// function body; only the circular-aware version is kept.)
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
    kernel_bounds bounds;
    if (params.circular) {
        // circular padding: every tap maps to a valid (wrapped) input pixel,
        // so the full kernel window participates
        bounds.y_min = 0;
        bounds.y_max = params.kernel_h;
        bounds.x_min = 0;
        bounds.x_max = params.kernel_w;
    } else {
        // zero padding: clip the window to taps whose input coordinate is
        // in-bounds, so the inner loops can skip per-tap bounds checks
        bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
        bounds.y_max =
            min(params.kernel_h,
                (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
        bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
        bounds.x_max =
            min(params.kernel_w,
                (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
    }
    return bounds;
}

Expand Down Expand Up @@ -83,7 +95,7 @@ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restr
const int in_w, const int in_h, const int out_w, const int out_h,
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
const int channels, const int batches) {
const int channels, const int batches, const int circular) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total_elements = batches * channels * out_h * out_w;

Expand All @@ -92,19 +104,30 @@ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restr
}

conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches, circular };

int batch_idx, channel_idx, out_y_idx, out_x_idx;
Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);

T accumulator = 0;
kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);


for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
if (params.circular) {
in_y_idx = wrap_coord(in_y_idx, params.in_h);
} else if (in_y_idx < 0 || in_y_idx >= params.in_h) {
continue;
}

for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
if (params.circular) {
in_x_idx = wrap_coord(in_x_idx, params.in_w);
} else if (in_x_idx < 0 || in_x_idx >= params.in_w) {
continue;
}

const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
Expand Down Expand Up @@ -132,6 +155,7 @@ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
const int padding_y = p[3];
const int dilation_x = p[4];
const int dilation_y = p[5];
const int circular = p[6];

const int in_w = input->ne[0];
const int in_h = input->ne[1];
Expand All @@ -150,11 +174,11 @@ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
if (ggml_is_contiguous(input)) {
conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
dilation_x, dilation_y, channels, batches, circular);
} else if (ggml_is_contiguous_channels(input)) {
conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
dilation_x, dilation_y, channels, batches, circular);
} else {
GGML_ABORT("Unsupported memory layout for conv_2d_dw");
}
Expand Down
Loading