mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fix general_conv padding
This commit is contained in:
parent
0cae0bdac8
commit
7942191a64
@ -22,7 +22,8 @@ void slow_conv_1D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@ -60,7 +61,8 @@ void slow_conv_1D(
|
||||
out_stride_O = out.strides()[2],
|
||||
|
||||
flip,
|
||||
padding = padding[0],
|
||||
padding_lo = padding_lo[0],
|
||||
padding_hi = padding_hi[0],
|
||||
wt_stride = wt_strides[0],
|
||||
wt_dilation = wt_dilation[0],
|
||||
in_dilation = in_dilation[0]]() mutable {
|
||||
@ -77,7 +79,7 @@ void slow_conv_1D(
|
||||
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
|
||||
|
||||
int wh_flip = flip ? (wH - wh - 1) : wh;
|
||||
int ih = oh * wt_stride - padding + wh_flip * wt_dilation;
|
||||
int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
|
||||
|
||||
auto ih_div = std::div(ih, in_dilation);
|
||||
|
||||
@ -109,7 +111,8 @@ void slow_conv_2D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@ -120,16 +123,14 @@ void slow_conv_2D(
|
||||
encoder.set_input_array(wt);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
encoder.dispatch([st_wt_ptr = wt.data<T>(),
|
||||
encoder.dispatch(
|
||||
[st_wt_ptr = wt.data<T>(),
|
||||
st_in_ptr = in.data<T>(),
|
||||
st_out_ptr = out.data<T>(),
|
||||
|
||||
N = in.shape(
|
||||
0), // Batch size, should be the same as out.shape(0)
|
||||
iH = 1 +
|
||||
in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
|
||||
iW = 1 +
|
||||
in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
|
||||
N = in.shape(0), // Batch size, should be the same as out.shape(0)
|
||||
iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
|
||||
iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
|
||||
C = in.shape(3), // In channels
|
||||
oH = out.shape(1), // Output spatial dim
|
||||
oW = out.shape(2), // Output spatial dim
|
||||
@ -155,7 +156,8 @@ void slow_conv_2D(
|
||||
out_stride_W = out.strides()[2],
|
||||
out_stride_O = out.strides()[3],
|
||||
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -163,14 +165,11 @@ void slow_conv_2D(
|
||||
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
|
||||
|
||||
const int O_per_group = O / groups;
|
||||
auto pt_conv_no_checks = [&](const T* in_ptr,
|
||||
const T* wt_ptr,
|
||||
T* out_ptr,
|
||||
int oh,
|
||||
int ow) {
|
||||
auto pt_conv_no_checks =
|
||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
int ih_base = oh * wt_strides[0] - padding_lo[0];
|
||||
int iw_base = ow * wt_strides[1] - padding_lo[1];
|
||||
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
@ -183,10 +182,13 @@ void slow_conv_2D(
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||
++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
@ -212,14 +214,16 @@ void slow_conv_2D(
|
||||
int f_wgt_jump_w =
|
||||
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
||||
|
||||
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||
int f_out_jump_h =
|
||||
std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||
int f_out_jump_w =
|
||||
std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||
|
||||
std::vector<int> base_h(f_out_jump_h);
|
||||
std::vector<int> base_w(f_out_jump_w);
|
||||
|
||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
|
||||
int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
|
||||
|
||||
int wh_base = 0;
|
||||
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
|
||||
@ -231,7 +235,7 @@ void slow_conv_2D(
|
||||
}
|
||||
|
||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
|
||||
int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
|
||||
|
||||
int ww_base = 0;
|
||||
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
|
||||
@ -246,8 +250,8 @@ void slow_conv_2D(
|
||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
int ih_base = oh * wt_strides[0] - padding_lo[0];
|
||||
int iw_base = ow * wt_strides[1] - padding_lo[1];
|
||||
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
int ww_base = base_w[ow % f_out_jump_w];
|
||||
@ -270,8 +274,8 @@ void slow_conv_2D(
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
|
||||
iw_dil * in_stride_W;
|
||||
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||
++c) {
|
||||
@ -292,17 +296,21 @@ void slow_conv_2D(
|
||||
};
|
||||
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 =
|
||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
|
||||
int oH_border_1 = is_idil_one
|
||||
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
|
||||
: oH;
|
||||
int oH_border_2 = std::max(
|
||||
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
|
||||
oH_border_1,
|
||||
(iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
|
||||
int oH_border_3 = oH;
|
||||
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 =
|
||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
|
||||
int oW_border_1 = is_idil_one
|
||||
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
|
||||
: oW;
|
||||
int oW_border_2 = std::max(
|
||||
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
|
||||
oW_border_1,
|
||||
(iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
|
||||
int oW_border_3 = oW;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
@ -351,7 +359,8 @@ void slow_conv_3D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@ -400,7 +409,8 @@ void slow_conv_3D(
|
||||
out_stride_H = out.strides()[2],
|
||||
out_stride_W = out.strides()[3],
|
||||
out_stride_O = out.strides()[4],
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -415,9 +425,9 @@ void slow_conv_3D(
|
||||
int oh,
|
||||
int ow) {
|
||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||
int id_base = od * wt_strides[0] - padding[0];
|
||||
int ih_base = oh * wt_strides[1] - padding[1];
|
||||
int iw_base = ow * wt_strides[2] - padding[2];
|
||||
int id_base = od * wt_strides[0] - padding_lo[0];
|
||||
int ih_base = oh * wt_strides[1] - padding_lo[1];
|
||||
int iw_base = ow * wt_strides[2] - padding_lo[2];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
@ -478,7 +488,7 @@ void slow_conv_3D(
|
||||
std::vector<int> base_w(f_out_jump_w);
|
||||
|
||||
for (int i = 0; i < f_out_jump_d; ++i) {
|
||||
int id_loop = i * wt_strides[0] - padding[0] + init_d;
|
||||
int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
|
||||
|
||||
int wd_base = 0;
|
||||
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
|
||||
@ -490,7 +500,7 @@ void slow_conv_3D(
|
||||
}
|
||||
|
||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
|
||||
int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
|
||||
|
||||
int wh_base = 0;
|
||||
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
|
||||
@ -502,7 +512,7 @@ void slow_conv_3D(
|
||||
}
|
||||
|
||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
|
||||
int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
|
||||
|
||||
int ww_base = 0;
|
||||
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
|
||||
@ -521,9 +531,9 @@ void slow_conv_3D(
|
||||
int ow) {
|
||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||
|
||||
int id_base = od * wt_strides[0] - padding[0];
|
||||
int ih_base = oh * wt_strides[1] - padding[1];
|
||||
int iw_base = ow * wt_strides[2] - padding[2];
|
||||
int id_base = od * wt_strides[0] - padding_lo[0];
|
||||
int ih_base = oh * wt_strides[1] - padding_lo[1];
|
||||
int iw_base = ow * wt_strides[2] - padding_lo[2];
|
||||
|
||||
int wd_base = base_d[od % f_out_jump_d];
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
@ -573,24 +583,30 @@ void slow_conv_3D(
|
||||
};
|
||||
|
||||
int oD_border_0 = 0;
|
||||
int oD_border_1 =
|
||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
|
||||
int oD_border_1 = is_idil_one
|
||||
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
|
||||
: oD;
|
||||
int oD_border_2 = std::max(
|
||||
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
||||
oD_border_1,
|
||||
(iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
||||
int oD_border_3 = oD;
|
||||
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 =
|
||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
|
||||
int oH_border_1 = is_idil_one
|
||||
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
|
||||
: oH;
|
||||
int oH_border_2 = std::max(
|
||||
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
||||
oH_border_1,
|
||||
(iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
||||
int oH_border_3 = oH;
|
||||
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 =
|
||||
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
|
||||
int oW_border_1 = is_idil_one
|
||||
? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
|
||||
: oW;
|
||||
int oW_border_2 = std::max(
|
||||
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
||||
oW_border_1,
|
||||
(iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
||||
int oW_border_3 = oW;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
@ -658,7 +674,8 @@ void dispatch_slow_conv_1D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@ -669,7 +686,8 @@ void dispatch_slow_conv_1D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -680,7 +698,8 @@ void dispatch_slow_conv_1D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -691,7 +710,8 @@ void dispatch_slow_conv_1D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -707,7 +727,8 @@ void dispatch_slow_conv_2D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@ -718,7 +739,8 @@ void dispatch_slow_conv_2D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -729,7 +751,8 @@ void dispatch_slow_conv_2D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -740,7 +763,8 @@ void dispatch_slow_conv_2D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -756,7 +780,8 @@ void dispatch_slow_conv_3D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@ -767,7 +792,8 @@ void dispatch_slow_conv_3D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -778,7 +804,8 @@ void dispatch_slow_conv_3D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -789,7 +816,8 @@ void dispatch_slow_conv_3D(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
Stream stream) {
|
||||
@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
|
||||
// Pad input
|
||||
Shape padded_shape = {N, iH + 2 * padding[0], C};
|
||||
Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = padding[0] * in_padded.strides()[1];
|
||||
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
Stream stream) {
|
||||
@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu(
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
|
||||
// Pad input
|
||||
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
|
||||
Shape padded_shape = {
|
||||
N,
|
||||
iH + padding_lo[0] + padding_hi[0],
|
||||
iW + padding_lo[1] + padding_hi[1],
|
||||
C};
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu(
|
||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset =
|
||||
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
|
||||
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
|
||||
padding_lo[1] * in_padded.strides()[2];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const bool flip,
|
||||
@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu(
|
||||
Shape padded_shape(in.shape().size());
|
||||
padded_shape.front() = N;
|
||||
for (size_t i = 0; i < iDim.size(); i++) {
|
||||
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
|
||||
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
|
||||
}
|
||||
padded_shape.back() = C;
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu(
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = 0;
|
||||
for (size_t i = 0; i < padding.size(); i++) {
|
||||
data_offset += padding[i] * in_padded.strides()[i + 1];
|
||||
for (size_t i = 0; i < padding_lo.size(); i++) {
|
||||
data_offset += padding_lo[i] * in_padded.strides()[i + 1];
|
||||
}
|
||||
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
@ -1261,7 +1297,8 @@ void conv_1D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@ -1270,22 +1307,40 @@ void conv_1D_cpu(
|
||||
const int groups = in.shape().back() / wt.shape().back();
|
||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
|
||||
return explicit_gemm_conv_1D_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, stream);
|
||||
in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
|
||||
}
|
||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_1D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
void conv_2D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@ -1295,18 +1350,35 @@ void conv_2D_cpu(
|
||||
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
|
||||
in_dilation[1] == 1 && groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_2D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
void conv_3D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
@ -1317,11 +1389,28 @@ void conv_3D_cpu(
|
||||
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
|
||||
groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_3D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
wt_strides,
|
||||
wt_dilation,
|
||||
in_dilation,
|
||||
flip,
|
||||
stream);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
|
@ -944,6 +944,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
wt = arr_copy;
|
||||
}
|
||||
|
||||
auto padding_ = padding_lo_;
|
||||
|
||||
// 3D conv
|
||||
if (out.ndim() == 5) {
|
||||
conv_3D_gpu(
|
||||
|
@ -3974,6 +3974,7 @@ array conv_general(
|
||||
to_stream(s),
|
||||
stride,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
groups,
|
||||
|
@ -1055,7 +1055,8 @@ array conv_weight_backward_patches(
|
||||
const array& wt,
|
||||
const array& cotan,
|
||||
const std::vector<int>& kernel_strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
StreamOrDevice s) {
|
||||
// Resolve Padded input shapes and strides
|
||||
Shape padding_starts(in.ndim(), 0);
|
||||
@ -1064,9 +1065,9 @@ array conv_weight_backward_patches(
|
||||
|
||||
// padded shape
|
||||
for (int i = 1; i < in.ndim() - 1; i++) {
|
||||
in_padded_shape[i] += 2 * padding[i - 1];
|
||||
padding_ends[i] += padding[i - 1];
|
||||
padding_starts[i] += padding[i - 1];
|
||||
in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1];
|
||||
padding_ends[i] += padding_lo[i - 1];
|
||||
padding_starts[i] += padding_lo[i - 1];
|
||||
}
|
||||
|
||||
// padded strides (contiguous)
|
||||
@ -1078,9 +1079,16 @@ array conv_weight_backward_patches(
|
||||
// Pad input
|
||||
std::vector<int> padded_axes(in.ndim() - 2, 0);
|
||||
std::iota(padded_axes.begin(), padded_axes.end(), 1);
|
||||
Shape padding_(padding.begin(), padding.end());
|
||||
auto in_padded = pad(
|
||||
in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
|
||||
Shape padding_lo_(padding_lo.begin(), padding_lo.end());
|
||||
Shape padding_hi_(padding_hi.begin(), padding_hi.end());
|
||||
auto in_padded =
|
||||
pad(in,
|
||||
padded_axes,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
array(0, in.dtype()),
|
||||
"constant",
|
||||
s);
|
||||
|
||||
// Resolve strided patches
|
||||
|
||||
@ -1147,16 +1155,16 @@ std::vector<array> Convolution::vjp(
|
||||
for (int a : argnums) {
|
||||
// Grads for input
|
||||
if (a == 0) {
|
||||
std::vector<int> padding_lo = padding_;
|
||||
std::vector<int> padding_hi = padding_;
|
||||
std::vector<int> padding_lo = padding_lo_;
|
||||
std::vector<int> padding_hi = padding_hi_;
|
||||
|
||||
for (int i = 0; i < padding_lo.size(); ++i) {
|
||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||
padding_lo[i] = wt_size - padding_[i] - 1;
|
||||
padding_lo[i] = wt_size - padding_lo_[i] - 1;
|
||||
|
||||
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
||||
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
||||
padding_hi[i] = in_size - out_size + padding_[i];
|
||||
padding_hi[i] = in_size - out_size + padding_hi_[i];
|
||||
}
|
||||
|
||||
// Check for negative padding
|
||||
@ -1226,17 +1234,17 @@ std::vector<array> Convolution::vjp(
|
||||
|
||||
if (no_dilation && !flip_ && groups_ == 1) {
|
||||
auto grad = conv_weight_backward_patches(
|
||||
in, wt, cotan, kernel_strides_, padding_, stream());
|
||||
in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream());
|
||||
grads.push_back(grad);
|
||||
} else {
|
||||
std::vector<int> padding_lo = padding_;
|
||||
std::vector<int> padding_hi = padding_;
|
||||
std::vector<int> padding_lo = padding_lo_;
|
||||
std::vector<int> padding_hi = padding_hi_;
|
||||
|
||||
for (int i = 0; i < padding_hi.size(); ++i) {
|
||||
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
||||
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
|
||||
padding_hi[i] = out_size - in_size + wt_size - padding_hi_[i] - 1;
|
||||
}
|
||||
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
|
||||
auto in_trans = group_transpose(in, -1, 0, -1);
|
||||
@ -1332,7 +1340,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
|
||||
|
||||
bool Convolution::is_equivalent(const Primitive& other) const {
|
||||
const Convolution& c_other = static_cast<const Convolution&>(other);
|
||||
return padding_ == c_other.padding_ &&
|
||||
return padding_lo_ == c_other.padding_lo_ &&
|
||||
padding_hi_ == c_other.padding_hi_ &&
|
||||
kernel_strides_ == c_other.kernel_strides_ &&
|
||||
kernel_dilation_ == c_other.kernel_dilation_ &&
|
||||
input_dilation_ == c_other.input_dilation_ &&
|
||||
|
@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive {
|
||||
explicit Convolution(
|
||||
Stream stream,
|
||||
const std::vector<int>& kernel_strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& padding_lo,
|
||||
const std::vector<int>& padding_hi,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
const int groups = 1,
|
||||
const bool flip = false)
|
||||
: UnaryPrimitive(stream),
|
||||
padding_(padding),
|
||||
padding_lo_(padding_lo),
|
||||
padding_hi_(padding_hi),
|
||||
kernel_strides_(kernel_strides),
|
||||
kernel_dilation_(kernel_dilation),
|
||||
input_dilation_(input_dilation),
|
||||
@ -716,7 +718,8 @@ class Convolution : public UnaryPrimitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
auto state() const {
|
||||
return std::make_tuple(
|
||||
padding_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive {
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> padding_;
|
||||
std::vector<int> padding_lo_;
|
||||
std::vector<int> padding_hi_;
|
||||
std::vector<int> kernel_strides_;
|
||||
std::vector<int> kernel_dilation_;
|
||||
std::vector<int> input_dilation_;
|
||||
|
Loading…
Reference in New Issue
Block a user