diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp index d52f92f8b..e5636b3b8 100644 --- a/mlx/backend/cpu/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -22,7 +22,8 @@ void slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -60,7 +61,8 @@ void slow_conv_1D( out_stride_O = out.strides()[2], flip, - padding = padding[0], + padding_lo = padding_lo[0], + padding_hi = padding_hi[0], wt_stride = wt_strides[0], wt_dilation = wt_dilation[0], in_dilation = in_dilation[0]]() mutable { @@ -77,7 +79,7 @@ void slow_conv_1D( const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; int wh_flip = flip ? (wH - wh - 1) : wh; - int ih = oh * wt_stride - padding + wh_flip * wt_dilation; + int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation; auto ih_div = std::div(ih, in_dilation); @@ -109,7 +111,8 @@ void slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -120,230 +123,235 @@ void slow_conv_2D( encoder.set_input_array(wt); encoder.set_output_array(out); - encoder.dispatch([st_wt_ptr = wt.data(), - st_in_ptr = in.data(), - st_out_ptr = out.data(), + encoder.dispatch( + [st_wt_ptr = wt.data(), + st_in_ptr = in.data(), + st_out_ptr = out.data(), - N = in.shape( - 0), // Batch size, should be the same as out.shape(0) - iH = 1 + - in_dilation[0] * (in.shape(1) - 1), // Input spatial dim - iW = 1 + - in_dilation[1] * (in.shape(2) - 1), // Input spatial dim - C = in.shape(3), // In channels - oH = out.shape(1), // Output spatial dim - oW = out.shape(2), // Output spatial dim - O = wt.shape(0), // Out channels - wH = wt.shape(1), // Weight spatial dim - wW = wt.shape(2), // Weight spatial dim + N = in.shape(0), // Batch size, should be the same as out.shape(0) + iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim + iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim + C = in.shape(3), // In channels + oH = out.shape(1), // Output spatial dim + oW = out.shape(2), // Output spatial dim + O = wt.shape(0), // Out channels + wH = wt.shape(1), // Weight spatial dim + wW = wt.shape(2), // Weight spatial dim - groups = in.shape(3) / wt.shape(3), - C_per_group = wt.shape(3), + groups = in.shape(3) / wt.shape(3), + C_per_group = wt.shape(3), - in_stride_N = in.strides()[0], - in_stride_H = in.strides()[1], - in_stride_W = in.strides()[2], - in_stride_C = in.strides()[3], + in_stride_N = in.strides()[0], + in_stride_H = in.strides()[1], + in_stride_W = in.strides()[2], + in_stride_C = in.strides()[3], - wt_stride_O = wt.strides()[0], - wt_stride_H = wt.strides()[1], - wt_stride_W = wt.strides()[2], - wt_stride_C = wt.strides()[3], + wt_stride_O = wt.strides()[0], + wt_stride_H = wt.strides()[1], + wt_stride_W = wt.strides()[2], + wt_stride_C = wt.strides()[3], - out_stride_N = out.strides()[0], - out_stride_H = out.strides()[1], - out_stride_W = out.strides()[2], - out_stride_O = out.strides()[3], + out_stride_N = out.strides()[0], + out_stride_H = out.strides()[1], + out_stride_W = out.strides()[2], + out_stride_O = out.strides()[3], - padding, - wt_strides, - wt_dilation, - in_dilation, - flip]() mutable { - bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip]() mutable { + bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; - const int O_per_group = O / groups; - auto pt_conv_no_checks = [&](const T* in_ptr, - const T* wt_ptr, - T* out_ptr, - int oh, - int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + const int O_per_group = O / groups; + auto pt_conv_no_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = + in_ptr + ih * in_stride_H + iw * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c - } // ww - } // wh + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; - int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; + int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; + int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; - int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); - int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); + int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); + int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); - int f_wgt_jump_h = - std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; - int f_wgt_jump_w = - std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; + int f_wgt_jump_h = + std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; + int f_wgt_jump_w = + std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; - int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; - int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; + int f_out_jump_h = + std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; + int f_out_jump_w = + std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; - std::vector base_h(f_out_jump_h); - std::vector base_w(f_out_jump_w); + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); - for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[0] - padding[0] + init_h; + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h; - int wh_base = 0; - while (wh_base < wH && ih_loop % in_dilation[0] != 0) { - wh_base++; - ih_loop += jump_h; - } + int wh_base = 0; + while (wh_base < wH && ih_loop % in_dilation[0] != 0) { + wh_base++; + ih_loop += jump_h; + } - base_h[i] = wh_base; - } + base_h[i] = wh_base; + } - for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[1] - padding[1] + init_w; + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w; - int ww_base = 0; - while (ww_base < wW && iw_loop % in_dilation[1] != 0) { - ww_base++; - iw_loop += jump_w; - } + int ww_base = 0; + while (ww_base < wW && iw_loop % in_dilation[1] != 0) { + ww_base++; + iw_loop += jump_w; + } - base_w[j] = ww_base; - } + base_w[j] = ww_base; + } - auto pt_conv_all_checks = - [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { - out_ptr += oh * out_stride_H + ow * out_stride_W; + auto pt_conv_all_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; - int ih_base = oh * wt_strides[0] - padding[0]; - int iw_base = ow * wt_strides[1] - padding[1]; + int ih_base = oh * wt_strides[0] - padding_lo[0]; + int iw_base = ow * wt_strides[1] - padding_lo[1]; - int wh_base = base_h[oh % f_out_jump_h]; - int ww_base = base_w[ow % f_out_jump_w]; + int wh_base = base_h[oh % f_out_jump_h]; + int ww_base = base_w[ow % f_out_jump_w]; - for (int g = 0; g < groups; ++g) { - for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { - for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { - const T* wt_ptr_pt = - wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; - int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; + int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; - const T* in_ptr_pt = - in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; + const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H + + iw_dil * in_stride_W; - for (int c = g * C_per_group; c < (g + 1) * C_per_group; - ++c) { - r += static_cast(in_ptr_pt[c * in_stride_C]) * - static_cast( - wt_ptr_pt[(c % C_per_group) * wt_stride_C]); - } // c + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c - } // ih, iw check - } // ww - } // wh + } // ih, iw check + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o - } // g - }; + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g + }; - int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; - int oH_border_2 = std::max( - oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); - int oH_border_3 = oH; + int oH_border_0 = 0; + int oH_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oH; + int oH_border_2 = std::max( + oH_border_1, + (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]); + int oH_border_3 = oH; - int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; - int oW_border_2 = std::max( - oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); - int oW_border_3 = oW; + int oW_border_0 = 0; + int oW_border_1 = is_idil_one + ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oW; + int oW_border_2 = std::max( + oW_border_1, + (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]); + int oW_border_3 = oW; - for (int n = 0; n < N; ++n) { - // Case 1: oh might put us out of bounds - for (int oh = oH_border_0; oh < oH_border_1; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + for (int n = 0; n < N; ++n) { + // Case 1: oh might put us out of bounds + for (int oh = oH_border_0; oh < oH_border_1; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - // Case 2: oh in bounds - for (int oh = oH_border_1; oh < oH_border_2; ++oh) { - // Case a: ow might put us out of bounds - for (int ow = oW_border_0; ow < oW_border_1; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case 2: oh in bounds + for (int oh = oH_border_1; oh < oH_border_2; ++oh) { + // Case a: ow might put us out of bounds + for (int ow = oW_border_0; ow < oW_border_1; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case b: ow in bounds - for (int ow = oW_border_1; ow < oW_border_2; ++ow) { - pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case b: ow in bounds + for (int ow = oW_border_1; ow < oW_border_2; ++ow) { + pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - // Case c: ow might put us out of bounds - for (int ow = oW_border_2; ow < oW_border_3; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow + // Case c: ow might put us out of bounds + for (int ow = oW_border_2; ow < oW_border_3; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow - } // oh + } // oh - // Case 3: oh might put us out of bounds - for (int oh = oH_border_2; oh < oH_border_3; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); - } // ow - } // oh + // Case 3: oh might put us out of bounds + for (int oh = oH_border_2; oh < oH_border_3; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh - st_in_ptr += in_stride_N; - st_out_ptr += out_stride_N; + st_in_ptr += in_stride_N; + st_out_ptr += out_stride_N; - } // n - }); + } // n + }); } template @@ -351,7 +359,8 @@ void slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -400,7 +409,8 @@ void slow_conv_3D( out_stride_H = out.strides()[2], out_stride_W = out.strides()[3], out_stride_O = out.strides()[4], - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -415,9 +425,9 @@ void slow_conv_3D( int oh, int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; for (int o = 0; o < O; ++o) { float r = 0.; @@ -478,7 +488,7 @@ void slow_conv_3D( std::vector base_w(f_out_jump_w); for (int i = 0; i < f_out_jump_d; ++i) { - int id_loop = i * wt_strides[0] - padding[0] + init_d; + int id_loop = i * wt_strides[0] - padding_lo[0] + init_d; int wd_base = 0; while (wd_base < wD && id_loop % in_dilation[0] != 0) { @@ -490,7 +500,7 @@ void slow_conv_3D( } for (int i = 0; i < f_out_jump_h; ++i) { - int ih_loop = i * wt_strides[1] - padding[1] + init_h; + int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h; int wh_base = 0; while (wh_base < wH && ih_loop % in_dilation[1] != 0) { @@ -502,7 +512,7 @@ void slow_conv_3D( } for (int j = 0; j < f_out_jump_w; ++j) { - int iw_loop = j * wt_strides[2] - padding[2] + init_w; + int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w; int ww_base = 0; while (ww_base < wW && iw_loop % in_dilation[2] != 0) { @@ -521,9 +531,9 @@ void slow_conv_3D( int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; - int id_base = od * wt_strides[0] - padding[0]; - int ih_base = oh * wt_strides[1] - padding[1]; - int iw_base = ow * wt_strides[2] - padding[2]; + int id_base = od * wt_strides[0] - padding_lo[0]; + int ih_base = oh * wt_strides[1] - padding_lo[1]; + int iw_base = ow * wt_strides[2] - padding_lo[2]; int wd_base = base_d[od % f_out_jump_d]; int wh_base = base_h[oh % f_out_jump_h]; @@ -573,24 +583,30 @@ void slow_conv_3D( }; int oD_border_0 = 0; - int oD_border_1 = - is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; + int oD_border_1 = is_idil_one + ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0]) + : oD; int oD_border_2 = std::max( - oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); + oD_border_1, + (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]); int oD_border_3 = oD; int oH_border_0 = 0; - int oH_border_1 = - is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; + int oH_border_1 = is_idil_one + ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1]) + : oH; int oH_border_2 = std::max( - oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); + oH_border_1, + (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]); int oH_border_3 = oH; int oW_border_0 = 0; - int oW_border_1 = - is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; + int oW_border_1 = is_idil_one + ? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2]) + : oW; int oW_border_2 = std::max( - oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); + oW_border_1, + (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { @@ -658,7 +674,8 @@ void dispatch_slow_conv_1D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -669,7 +686,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -680,7 +698,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -691,7 +710,8 @@ void dispatch_slow_conv_1D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -707,7 +727,8 @@ void dispatch_slow_conv_2D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -718,7 +739,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -729,7 +751,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -740,7 +763,8 @@ void dispatch_slow_conv_2D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -756,7 +780,8 @@ void dispatch_slow_conv_3D( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -767,7 +792,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -778,7 +804,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -789,7 +816,8 @@ void dispatch_slow_conv_3D( in, wt, out, - padding, + padding_lo, + padding_hi, wt_strides, wt_dilation, in_dilation, @@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], C}; + Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = padding[0] * in_padded.strides()[1]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, Stream stream) { @@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu( auto& encoder = cpu::get_command_encoder(stream); // Pad input - Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C}; + Shape padded_shape = { + N, + iH + padding_lo[0] + padding_hi[0], + iW + padding_lo[1] + padding_hi[1], + C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros @@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu( copy(temps.back(), in_padded, CopyType::Scalar, stream); // Pick input slice from padded - size_t data_offset = - padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2]; + size_t data_offset = padding_lo[0] * in_padded.strides()[1] + + padding_lo[1] * in_padded.strides()[2]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const bool flip, @@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu( Shape padded_shape(in.shape().size()); padded_shape.front() = N; for (size_t i = 0; i < iDim.size(); i++) { - padded_shape[i + 1] = iDim[i] + 2 * padding[i]; + padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i]; } padded_shape.back() = C; array in_padded(padded_shape, conv_dtype, nullptr, {}); @@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu( // Pick input slice from padded size_t data_offset = 0; - for (size_t i = 0; i < padding.size(); i++) { - data_offset += padding[i] * in_padded.strides()[i + 1]; + for (size_t i = 0; i < padding_lo.size(); i++) { + data_offset += padding_lo[i] * in_padded.strides()[i + 1]; } + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, @@ -1261,7 +1297,8 @@ void conv_1D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1270,22 +1307,40 @@ void conv_1D_cpu( const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( - in, wt, out, padding, wt_strides, wt_dilation, stream); + in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream); } if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_2D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1295,18 +1350,35 @@ void conv_2D_cpu( if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } - return dispatch_slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } void conv_3D_cpu( const array& in, const array& wt, array out, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, @@ -1317,11 +1389,28 @@ void conv_3D_cpu( in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( - in, wt, out, padding, wt_strides, wt_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + flip, + stream); } return dispatch_slow_conv_3D( - in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); + in, + wt, + out, + padding_lo, + padding_hi, + wt_strides, + wt_dilation, + in_dilation, + flip, + stream); } } // namespace @@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index ae31a6cff..35ed3d44e 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -952,7 +952,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -967,7 +967,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -983,7 +983,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4aa5e88b7..e8c260425 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3974,6 +3974,7 @@ array conv_general( to_stream(s), stride, padding_lo, + padding_hi, kernel_dilation, input_dilation, groups, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 7288a4885..03ca06bdd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1055,7 +1055,8 @@ array conv_weight_backward_patches( const array& wt, const array& cotan, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, StreamOrDevice s) { // Resolve Padded input shapes and strides Shape padding_starts(in.ndim(), 0); @@ -1064,9 +1065,9 @@ array conv_weight_backward_patches( // padded shape for (int i = 1; i < in.ndim() - 1; i++) { - in_padded_shape[i] += 2 * padding[i - 1]; - padding_ends[i] += padding[i - 1]; - padding_starts[i] += padding[i - 1]; + in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1]; + padding_ends[i] += padding_lo[i - 1]; + padding_starts[i] += padding_lo[i - 1]; } // padded strides (contiguous) @@ -1078,9 +1079,16 @@ array conv_weight_backward_patches( // Pad input std::vector padded_axes(in.ndim() - 2, 0); std::iota(padded_axes.begin(), padded_axes.end(), 1); - Shape padding_(padding.begin(), padding.end()); - auto in_padded = pad( - in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s); + Shape padding_lo_(padding_lo.begin(), padding_lo.end()); + Shape padding_hi_(padding_hi.begin(), padding_hi.end()); + auto in_padded = + pad(in, + padded_axes, + padding_lo_, + padding_hi_, + array(0, in.dtype()), + "constant", + s); // Resolve strided patches @@ -1147,16 +1155,16 @@ std::vector Convolution::vjp( for (int a : argnums) { // Grads for input if (a == 0) { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + std::vector padding_lo = padding_lo_; + std::vector padding_hi = padding_hi_; for (int i = 0; i < padding_lo.size(); ++i) { int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_lo[i] = wt_size - padding_[i] - 1; + padding_lo[i] = wt_size - padding_lo_[i] - 1; int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - padding_hi[i] = in_size - out_size + padding_[i]; + padding_hi[i] = in_size - out_size + padding_hi_[i]; } // Check for negative padding @@ -1226,18 +1234,12 @@ std::vector Convolution::vjp( if (no_dilation && !flip_ && groups_ == 1) { auto grad = conv_weight_backward_patches( - in, wt, cotan, kernel_strides_, padding_, stream()); + in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream()); grads.push_back(grad); } else { - std::vector padding_lo = padding_; - std::vector padding_hi = padding_; + std::vector padding_lo = padding_lo_; + std::vector padding_hi = padding_hi_; - for (int i = 0; i < padding_hi.size(); ++i) { - int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); - int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1; - } auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto in_trans = group_transpose(in, -1, 0, -1); @@ -1283,7 +1285,8 @@ std::pair, std::vector> Convolution::vmap( in, w, kernel_strides_, - padding_, + padding_lo_, + padding_hi_, kernel_dilation_, input_dilation_, groups, @@ -1332,7 +1335,8 @@ std::pair, std::vector> Convolution::vmap( bool Convolution::is_equivalent(const Primitive& other) const { const Convolution& c_other = static_cast(other); - return padding_ == c_other.padding_ && + return padding_lo_ == c_other.padding_lo_ && + padding_hi_ == c_other.padding_hi_ && kernel_strides_ == c_other.kernel_strides_ && kernel_dilation_ == c_other.kernel_dilation_ && input_dilation_ == c_other.input_dilation_ && diff --git a/mlx/primitives.h b/mlx/primitives.h index 3753e43c5..2caed8477 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive { explicit Convolution( Stream stream, const std::vector& kernel_strides, - const std::vector& padding, + const std::vector& padding_lo, + const std::vector& padding_hi, const std::vector& kernel_dilation, const std::vector& input_dilation, const int groups = 1, const bool flip = false) : UnaryPrimitive(stream), - padding_(padding), + padding_lo_(padding_lo), + padding_hi_(padding_hi), kernel_strides_(kernel_strides), kernel_dilation_(kernel_dilation), input_dilation_(input_dilation), @@ -716,7 +718,8 @@ class Convolution : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( - padding_, + padding_lo_, + padding_hi_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive { } private: - std::vector padding_; + std::vector padding_lo_; + std::vector padding_hi_; std::vector kernel_strides_; std::vector kernel_dilation_; std::vector input_dilation_; diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 671c86a32..35dcf42ac 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase): atol=2e-5 if dtype == np.float32 else 5e-4, ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_asymmetric_padding(self): + inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32) + strides = (2, 2, 2) + + pt_out = torch.conv3d( + torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)), + torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)), + stride=strides, + padding=2, + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=strides, + padding=([0, 0, 0], [1, 1, 1]), + ) + + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + + inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32) + + pt_out = torch.conv2d( + torch.permute(torch.tensor(inputs), (0, 3, 1, 2)), + torch.permute(torch.tensor(kernel), (0, 3, 1, 2)), + stride=1, + padding=(1, 0), + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=1, + padding=([0, 0], [1, 0]), + ) + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + if __name__ == "__main__": unittest.main()