fix general_conv padding

This commit is contained in:
a-turker 2025-04-12 11:56:10 +03:00 committed by Awni Hannun
parent 0cae0bdac8
commit 7942191a64
5 changed files with 369 additions and 261 deletions

View File

@ -22,7 +22,8 @@ void slow_conv_1D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@ -60,7 +61,8 @@ void slow_conv_1D(
out_stride_O = out.strides()[2], out_stride_O = out.strides()[2],
flip, flip,
padding = padding[0], padding_lo = padding_lo[0],
padding_hi = padding_hi[0],
wt_stride = wt_strides[0], wt_stride = wt_strides[0],
wt_dilation = wt_dilation[0], wt_dilation = wt_dilation[0],
in_dilation = in_dilation[0]]() mutable { in_dilation = in_dilation[0]]() mutable {
@ -77,7 +79,7 @@ void slow_conv_1D(
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh; int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_stride - padding + wh_flip * wt_dilation; int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
auto ih_div = std::div(ih, in_dilation); auto ih_div = std::div(ih, in_dilation);
@ -109,7 +111,8 @@ void slow_conv_2D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@ -120,230 +123,235 @@ void slow_conv_2D(
encoder.set_input_array(wt); encoder.set_input_array(wt);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.dispatch([st_wt_ptr = wt.data<T>(), encoder.dispatch(
st_in_ptr = in.data<T>(), [st_wt_ptr = wt.data<T>(),
st_out_ptr = out.data<T>(), st_in_ptr = in.data<T>(),
st_out_ptr = out.data<T>(),
N = in.shape( N = in.shape(0), // Batch size, should be the same as out.shape(0)
0), // Batch size, should be the same as out.shape(0) iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
iH = 1 + iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
in_dilation[0] * (in.shape(1) - 1), // Input spatial dim C = in.shape(3), // In channels
iW = 1 + oH = out.shape(1), // Output spatial dim
in_dilation[1] * (in.shape(2) - 1), // Input spatial dim oW = out.shape(2), // Output spatial dim
C = in.shape(3), // In channels O = wt.shape(0), // Out channels
oH = out.shape(1), // Output spatial dim wH = wt.shape(1), // Weight spatial dim
oW = out.shape(2), // Output spatial dim wW = wt.shape(2), // Weight spatial dim
O = wt.shape(0), // Out channels
wH = wt.shape(1), // Weight spatial dim
wW = wt.shape(2), // Weight spatial dim
groups = in.shape(3) / wt.shape(3), groups = in.shape(3) / wt.shape(3),
C_per_group = wt.shape(3), C_per_group = wt.shape(3),
in_stride_N = in.strides()[0], in_stride_N = in.strides()[0],
in_stride_H = in.strides()[1], in_stride_H = in.strides()[1],
in_stride_W = in.strides()[2], in_stride_W = in.strides()[2],
in_stride_C = in.strides()[3], in_stride_C = in.strides()[3],
wt_stride_O = wt.strides()[0], wt_stride_O = wt.strides()[0],
wt_stride_H = wt.strides()[1], wt_stride_H = wt.strides()[1],
wt_stride_W = wt.strides()[2], wt_stride_W = wt.strides()[2],
wt_stride_C = wt.strides()[3], wt_stride_C = wt.strides()[3],
out_stride_N = out.strides()[0], out_stride_N = out.strides()[0],
out_stride_H = out.strides()[1], out_stride_H = out.strides()[1],
out_stride_W = out.strides()[2], out_stride_W = out.strides()[2],
out_stride_O = out.strides()[3], out_stride_O = out.strides()[3],
padding, padding_lo,
wt_strides, padding_hi,
wt_dilation, wt_strides,
in_dilation, wt_dilation,
flip]() mutable { in_dilation,
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; flip]() mutable {
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
const int O_per_group = O / groups; const int O_per_group = O / groups;
auto pt_conv_no_checks = [&](const T* in_ptr, auto pt_conv_no_checks =
const T* wt_ptr, [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
T* out_ptr, out_ptr += oh * out_stride_H + ow * out_stride_W;
int oh, int ih_base = oh * wt_strides[0] - padding_lo[0];
int ow) { int iw_base = ow * wt_strides[1] - padding_lo[1];
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
for (int g = 0; g < groups; ++g) { for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.; float r = 0.;
for (int wh = 0; wh < wH; ++wh) { for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) { for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh; int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww; int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0]; int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1]; int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; const T* wt_ptr_pt =
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { for (int c = g * C_per_group; c < (g + 1) * C_per_group;
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) * ++c) {
static_cast<float>( r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
wt_ptr_pt[(c % C_per_group) * wt_stride_C]); static_cast<float>(
} // c wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // ww } // c
} // wh } // ww
} // wh
out_ptr[0] = static_cast<T>(r); out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O; out_ptr += out_stride_O;
wt_ptr += wt_stride_O; wt_ptr += wt_stride_O;
} // o } // o
} // g } // g
}; };
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int f_wgt_jump_h = int f_wgt_jump_h =
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w = int f_wgt_jump_w =
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; int f_out_jump_h =
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w =
std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
std::vector<int> base_h(f_out_jump_h); std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w); std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_h; ++i) { for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding[0] + init_h; int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
int wh_base = 0; int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) { while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++; wh_base++;
ih_loop += jump_h; ih_loop += jump_h;
} }
base_h[i] = wh_base; base_h[i] = wh_base;
} }
for (int j = 0; j < f_out_jump_w; ++j) { for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding[1] + init_w; int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
int ww_base = 0; int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) { while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++; ww_base++;
iw_loop += jump_w; iw_loop += jump_w;
} }
base_w[j] = ww_base; base_w[j] = ww_base;
} }
auto pt_conv_all_checks = auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W; out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0]; int ih_base = oh * wt_strides[0] - padding_lo[0];
int iw_base = ow * wt_strides[1] - padding[1]; int iw_base = ow * wt_strides[1] - padding_lo[1];
int wh_base = base_h[oh % f_out_jump_h]; int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w]; int ww_base = base_w[ow % f_out_jump_w];
for (int g = 0; g < groups; ++g) { for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.; float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh; int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww; int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0]; int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1]; int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt = const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W; wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt = const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; iw_dil * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group; for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) { ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) * r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>( static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]); wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c } // c
} // ih, iw check } // ih, iw check
} // ww } // ww
} // wh } // wh
out_ptr[0] = static_cast<T>(r); out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O; out_ptr += out_stride_O;
wt_ptr += wt_stride_O; wt_ptr += wt_stride_O;
} // o } // o
} // g } // g
}; };
int oH_border_0 = 0; int oH_border_0 = 0;
int oH_border_1 = int oH_border_1 = is_idil_one
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
int oH_border_2 = std::max( : oH;
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); int oH_border_2 = std::max(
int oH_border_3 = oH; oH_border_1,
(iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oW_border_0 = 0; int oW_border_0 = 0;
int oW_border_1 = int oW_border_1 = is_idil_one
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
int oW_border_2 = std::max( : oW;
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); int oW_border_2 = std::max(
int oW_border_3 = oW; oW_border_1,
(iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
// Case 1: oh might put us out of bounds // Case 1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) { for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) { for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
} // oh } // oh
// Case 2: oh in bounds // Case 2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) { for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case a: ow might put us out of bounds // Case a: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) { for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
// Case b: ow in bounds // Case b: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) { for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
// Case c: ow might put us out of bounds // Case c: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) { for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
} // oh } // oh
// Case 3: oh might put us out of bounds // Case 3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) { for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) { for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
} // oh } // oh
st_in_ptr += in_stride_N; st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N; st_out_ptr += out_stride_N;
} // n } // n
}); });
} }
template <typename T> template <typename T>
@ -351,7 +359,8 @@ void slow_conv_3D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@ -400,7 +409,8 @@ void slow_conv_3D(
out_stride_H = out.strides()[2], out_stride_H = out.strides()[2],
out_stride_W = out.strides()[3], out_stride_W = out.strides()[3],
out_stride_O = out.strides()[4], out_stride_O = out.strides()[4],
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -415,9 +425,9 @@ void slow_conv_3D(
int oh, int oh,
int ow) { int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0]; int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding[1]; int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding[2]; int iw_base = ow * wt_strides[2] - padding_lo[2];
for (int o = 0; o < O; ++o) { for (int o = 0; o < O; ++o) {
float r = 0.; float r = 0.;
@ -478,7 +488,7 @@ void slow_conv_3D(
std::vector<int> base_w(f_out_jump_w); std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_d; ++i) { for (int i = 0; i < f_out_jump_d; ++i) {
int id_loop = i * wt_strides[0] - padding[0] + init_d; int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
int wd_base = 0; int wd_base = 0;
while (wd_base < wD && id_loop % in_dilation[0] != 0) { while (wd_base < wD && id_loop % in_dilation[0] != 0) {
@ -490,7 +500,7 @@ void slow_conv_3D(
} }
for (int i = 0; i < f_out_jump_h; ++i) { for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[1] - padding[1] + init_h; int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
int wh_base = 0; int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[1] != 0) { while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
@ -502,7 +512,7 @@ void slow_conv_3D(
} }
for (int j = 0; j < f_out_jump_w; ++j) { for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[2] - padding[2] + init_w; int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
int ww_base = 0; int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[2] != 0) { while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
@ -521,9 +531,9 @@ void slow_conv_3D(
int ow) { int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0]; int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding[1]; int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding[2]; int iw_base = ow * wt_strides[2] - padding_lo[2];
int wd_base = base_d[od % f_out_jump_d]; int wd_base = base_d[od % f_out_jump_d];
int wh_base = base_h[oh % f_out_jump_h]; int wh_base = base_h[oh % f_out_jump_h];
@ -573,24 +583,30 @@ void slow_conv_3D(
}; };
int oD_border_0 = 0; int oD_border_0 = 0;
int oD_border_1 = int oD_border_1 = is_idil_one
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
: oD;
int oD_border_2 = std::max( int oD_border_2 = std::max(
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); oD_border_1,
(iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
int oD_border_3 = oD; int oD_border_3 = oD;
int oH_border_0 = 0; int oH_border_0 = 0;
int oH_border_1 = int oH_border_1 = is_idil_one
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
: oH;
int oH_border_2 = std::max( int oH_border_2 = std::max(
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); oH_border_1,
(iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
int oH_border_3 = oH; int oH_border_3 = oH;
int oW_border_0 = 0; int oW_border_0 = 0;
int oW_border_1 = int oW_border_1 = is_idil_one
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; ? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
: oW;
int oW_border_2 = std::max( int oW_border_2 = std::max(
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); oW_border_1,
(iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
int oW_border_3 = oW; int oW_border_3 = oW;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
@ -658,7 +674,8 @@ void dispatch_slow_conv_1D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@ -669,7 +686,8 @@ void dispatch_slow_conv_1D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -680,7 +698,8 @@ void dispatch_slow_conv_1D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -691,7 +710,8 @@ void dispatch_slow_conv_1D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -707,7 +727,8 @@ void dispatch_slow_conv_2D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@ -718,7 +739,8 @@ void dispatch_slow_conv_2D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -729,7 +751,8 @@ void dispatch_slow_conv_2D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -740,7 +763,8 @@ void dispatch_slow_conv_2D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -756,7 +780,8 @@ void dispatch_slow_conv_3D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@ -767,7 +792,8 @@ void dispatch_slow_conv_3D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -778,7 +804,8 @@ void dispatch_slow_conv_3D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -789,7 +816,8 @@ void dispatch_slow_conv_3D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
Stream stream) { Stream stream) {
@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu(
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
// Pad input // Pad input
Shape padded_shape = {N, iH + 2 * padding[0], C}; Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {}); array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros // Fill with zeros
@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream); copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded // Pick input slice from padded
size_t data_offset = padding[0] * in_padded.strides()[1]; size_t data_offset = padding_lo[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer( in_padded_slice.copy_shared_buffer(
in_padded, in_padded,
@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
Stream stream) { Stream stream) {
@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu(
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
// Pad input // Pad input
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C}; Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {}); array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros // Fill with zeros
@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream); copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded // Pick input slice from padded
size_t data_offset = size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2]; padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer( in_padded_slice.copy_shared_buffer(
in_padded, in_padded,
@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const bool flip, const bool flip,
@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu(
Shape padded_shape(in.shape().size()); Shape padded_shape(in.shape().size());
padded_shape.front() = N; padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) { for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i]; padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
} }
padded_shape.back() = C; padded_shape.back() = C;
array in_padded(padded_shape, conv_dtype, nullptr, {}); array in_padded(padded_shape, conv_dtype, nullptr, {});
@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu(
// Pick input slice from padded // Pick input slice from padded
size_t data_offset = 0; size_t data_offset = 0;
for (size_t i = 0; i < padding.size(); i++) { for (size_t i = 0; i < padding_lo.size(); i++) {
data_offset += padding[i] * in_padded.strides()[i + 1]; data_offset += padding_lo[i] * in_padded.strides()[i + 1];
} }
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer( in_padded_slice.copy_shared_buffer(
in_padded, in_padded,
@ -1261,7 +1297,8 @@ void conv_1D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@ -1270,22 +1307,40 @@ void conv_1D_cpu(
const int groups = in.shape().back() / wt.shape().back(); const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu( return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation, stream); in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
} }
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu( return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
} }
return dispatch_slow_conv_1D( return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
} }
void conv_2D_cpu( void conv_2D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@ -1295,18 +1350,35 @@ void conv_2D_cpu(
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
in_dilation[1] == 1 && groups == 1) { in_dilation[1] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu( return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
} }
return dispatch_slow_conv_2D( return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
} }
void conv_3D_cpu( void conv_3D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@ -1317,11 +1389,28 @@ void conv_3D_cpu(
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
groups == 1) { groups == 1) {
return explicit_gemm_conv_ND_cpu( return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
} }
return dispatch_slow_conv_3D( return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
} }
} // namespace } // namespace
@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
padding_hi_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
padding_hi_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
padding_hi_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,

View File

@ -944,6 +944,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
wt = arr_copy; wt = arr_copy;
} }
auto padding_ = padding_lo_;
// 3D conv // 3D conv
if (out.ndim() == 5) { if (out.ndim() == 5) {
conv_3D_gpu( conv_3D_gpu(

View File

@ -3974,6 +3974,7 @@ array conv_general(
to_stream(s), to_stream(s),
stride, stride,
padding_lo, padding_lo,
padding_hi,
kernel_dilation, kernel_dilation,
input_dilation, input_dilation,
groups, groups,

View File

@ -1055,7 +1055,8 @@ array conv_weight_backward_patches(
const array& wt, const array& wt,
const array& cotan, const array& cotan,
const std::vector<int>& kernel_strides, const std::vector<int>& kernel_strides,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
StreamOrDevice s) { StreamOrDevice s) {
// Resolve Padded input shapes and strides // Resolve Padded input shapes and strides
Shape padding_starts(in.ndim(), 0); Shape padding_starts(in.ndim(), 0);
@ -1064,9 +1065,9 @@ array conv_weight_backward_patches(
// padded shape // padded shape
for (int i = 1; i < in.ndim() - 1; i++) { for (int i = 1; i < in.ndim() - 1; i++) {
in_padded_shape[i] += 2 * padding[i - 1]; in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1];
padding_ends[i] += padding[i - 1]; padding_ends[i] += padding_lo[i - 1];
padding_starts[i] += padding[i - 1]; padding_starts[i] += padding_lo[i - 1];
} }
// padded strides (contiguous) // padded strides (contiguous)
@ -1078,9 +1079,16 @@ array conv_weight_backward_patches(
// Pad input // Pad input
std::vector<int> padded_axes(in.ndim() - 2, 0); std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1); std::iota(padded_axes.begin(), padded_axes.end(), 1);
Shape padding_(padding.begin(), padding.end()); Shape padding_lo_(padding_lo.begin(), padding_lo.end());
auto in_padded = pad( Shape padding_hi_(padding_hi.begin(), padding_hi.end());
in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s); auto in_padded =
pad(in,
padded_axes,
padding_lo_,
padding_hi_,
array(0, in.dtype()),
"constant",
s);
// Resolve strided patches // Resolve strided patches
@ -1147,16 +1155,16 @@ std::vector<array> Convolution::vjp(
for (int a : argnums) { for (int a : argnums) {
// Grads for input // Grads for input
if (a == 0) { if (a == 0) {
std::vector<int> padding_lo = padding_; std::vector<int> padding_lo = padding_lo_;
std::vector<int> padding_hi = padding_; std::vector<int> padding_hi = padding_hi_;
for (int i = 0; i < padding_lo.size(); ++i) { for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_[i] - 1; padding_lo[i] = wt_size - padding_lo_[i] - 1;
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
padding_hi[i] = in_size - out_size + padding_[i]; padding_hi[i] = in_size - out_size + padding_hi_[i];
} }
// Check for negative padding // Check for negative padding
@ -1226,17 +1234,17 @@ std::vector<array> Convolution::vjp(
if (no_dilation && !flip_ && groups_ == 1) { if (no_dilation && !flip_ && groups_ == 1) {
auto grad = conv_weight_backward_patches( auto grad = conv_weight_backward_patches(
in, wt, cotan, kernel_strides_, padding_, stream()); in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream());
grads.push_back(grad); grads.push_back(grad);
} else { } else {
std::vector<int> padding_lo = padding_; std::vector<int> padding_lo = padding_lo_;
std::vector<int> padding_hi = padding_; std::vector<int> padding_hi = padding_hi_;
for (int i = 0; i < padding_hi.size(); ++i) { for (int i = 0; i < padding_hi.size(); ++i) {
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; padding_hi[i] = out_size - in_size + wt_size - padding_hi_[i] - 1;
} }
auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto cotan_trans = swapaxes(cotan, 0, -1, stream());
auto in_trans = group_transpose(in, -1, 0, -1); auto in_trans = group_transpose(in, -1, 0, -1);
@ -1332,7 +1340,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
bool Convolution::is_equivalent(const Primitive& other) const { bool Convolution::is_equivalent(const Primitive& other) const {
const Convolution& c_other = static_cast<const Convolution&>(other); const Convolution& c_other = static_cast<const Convolution&>(other);
return padding_ == c_other.padding_ && return padding_lo_ == c_other.padding_lo_ &&
padding_hi_ == c_other.padding_hi_ &&
kernel_strides_ == c_other.kernel_strides_ && kernel_strides_ == c_other.kernel_strides_ &&
kernel_dilation_ == c_other.kernel_dilation_ && kernel_dilation_ == c_other.kernel_dilation_ &&
input_dilation_ == c_other.input_dilation_ && input_dilation_ == c_other.input_dilation_ &&

View File

@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive {
explicit Convolution( explicit Convolution(
Stream stream, Stream stream,
const std::vector<int>& kernel_strides, const std::vector<int>& kernel_strides,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& kernel_dilation, const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation, const std::vector<int>& input_dilation,
const int groups = 1, const int groups = 1,
const bool flip = false) const bool flip = false)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
padding_(padding), padding_lo_(padding_lo),
padding_hi_(padding_hi),
kernel_strides_(kernel_strides), kernel_strides_(kernel_strides),
kernel_dilation_(kernel_dilation), kernel_dilation_(kernel_dilation),
input_dilation_(input_dilation), input_dilation_(input_dilation),
@ -716,7 +718,8 @@ class Convolution : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
auto state() const { auto state() const {
return std::make_tuple( return std::make_tuple(
padding_, padding_lo_,
padding_hi_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive {
} }
private: private:
std::vector<int> padding_; std::vector<int> padding_lo_;
std::vector<int> padding_hi_;
std::vector<int> kernel_strides_; std::vector<int> kernel_strides_;
std::vector<int> kernel_dilation_; std::vector<int> kernel_dilation_;
std::vector<int> input_dilation_; std::vector<int> input_dilation_;