// Copyright © 2023-2024 Apple Inc. #include #include #ifdef ACCELERATE_NEW_LAPACK #include #else #include #endif #include "mlx/backend/common/copy.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { namespace { /////////////////////////////////////////////////////////////////////////////// // Naive reference conv /////////////////////////////////////////////////////////////////////////////// template void slow_conv_1D( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { const T* start_wt_ptr = wt.data(); const T* in_ptr = in.data(); T* out_ptr = out.data(); const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim const int C = in.shape(2); // Input channels const int oH = out.shape(1); // Output spatial dim const int O = wt.shape(0); // Out channels const int wH = wt.shape(1); // Weight spatial dim const int groups = C / wt.shape(2); const int C_per_group = wt.shape(2); const int O_per_group = O / groups; const size_t in_stride_N = in.strides()[0]; const size_t in_stride_H = in.strides()[1]; const size_t in_stride_C = in.strides()[2]; const size_t wt_stride_O = wt.strides()[0]; const size_t wt_stride_H = wt.strides()[1]; const size_t wt_stride_C = wt.strides()[2]; const size_t out_stride_N = out.strides()[0]; const size_t out_stride_H = out.strides()[1]; const size_t out_stride_O = out.strides()[2]; for (int n = 0; n < N; ++n) { for (int oh = 0; oh < oH; ++oh) { for (int g = 0; g < groups; ++g) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O; float r = 0.; for (int wh = 0; wh < wH; ++wh) { const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; int wh_flip = flip ? (wH - wh - 1) : wh; int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0]; auto ih_div = std::div(ih, in_dilation[0]); if (ih >= 0 && ih < iH && ih_div.rem == 0) { for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { r += static_cast( in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) * static_cast(wt_ptr[(c % C_per_group) * wt_stride_C]); } // c } // ih check } // wh out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast(r); } // o } // g } // oh in_ptr += in_stride_N; out_ptr += out_stride_N; } // n } template void slow_conv_2D( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { const T* st_wt_ptr = wt.data(); const T* st_in_ptr = in.data(); T* st_out_ptr = out.data(); const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim const int C = in.shape(3); // In channels const int oH = out.shape(1); // Output spatial dim const int oW = out.shape(2); // Output spatial dim const int O = wt.shape(0); // Out channels const int wH = wt.shape(1); // Weight spatial dim const int wW = wt.shape(2); // Weight spatial dim const int groups = C / wt.shape(3); const int C_per_group = wt.shape(3); const int O_per_group = O / groups; const size_t in_stride_N = in.strides()[0]; const size_t in_stride_H = in.strides()[1]; const size_t in_stride_W = in.strides()[2]; const size_t in_stride_C = in.strides()[3]; const size_t wt_stride_O = wt.strides()[0]; const size_t wt_stride_H = wt.strides()[1]; const size_t wt_stride_W = wt.strides()[2]; const size_t wt_stride_C = wt.strides()[3]; const size_t out_stride_N = out.strides()[0]; const size_t out_stride_H = out.strides()[1]; const size_t out_stride_W = out.strides()[2]; const size_t out_stride_O = out.strides()[3]; bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; auto pt_conv_no_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { out_ptr += oh * out_stride_H + ow * out_stride_W; int ih_base = oh * wt_strides[0] - padding[0]; int iw_base = ow * wt_strides[1] - padding[1]; for (int g = 0; g < groups; ++g) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { float r = 0.; for (int wh = 0; wh < wH; ++wh) { for (int ww = 0; ww < wW; ++ww) { int wh_flip = flip ? wH - wh - 1 : wh; int ww_flip = flip ? wW - ww - 1 : ww; int ih = ih_base + wh_flip * wt_dilation[0]; int iw = iw_base + ww_flip * wt_dilation[1]; const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { r += static_cast(in_ptr_pt[c * in_stride_C]) * static_cast( wt_ptr_pt[(c % C_per_group) * wt_stride_C]); } // c } // ww } // wh out_ptr[0] = static_cast(r); out_ptr += out_stride_O; wt_ptr += wt_stride_O; } // o } // g }; int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; std::vector base_h(f_out_jump_h); std::vector base_w(f_out_jump_w); for (int i = 0; i < f_out_jump_h; ++i) { int ih_loop = i * wt_strides[0] - padding[0] + init_h; int wh_base = 0; while (wh_base < wH && ih_loop % in_dilation[0] != 0) { wh_base++; ih_loop += jump_h; } base_h[i] = wh_base; } for (int j = 0; j < f_out_jump_w; ++j) { int iw_loop = j * wt_strides[1] - padding[1] + init_w; int ww_base = 0; while (ww_base < wW && iw_loop % in_dilation[1] != 0) { ww_base++; iw_loop += jump_w; } base_w[j] = ww_base; } auto pt_conv_all_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { out_ptr += oh * out_stride_H + ow * out_stride_W; int ih_base = oh * wt_strides[0] - padding[0]; int iw_base = ow * wt_strides[1] - padding[1]; int wh_base = base_h[oh % f_out_jump_h]; int ww_base = base_w[ow % f_out_jump_w]; for (int g = 0; g < groups; ++g) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { float r = 0.; for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { int wh_flip = flip ? wH - wh - 1 : wh; int ww_flip = flip ? wW - ww - 1 : ww; int ih = ih_base + wh_flip * wt_dilation[0]; int iw = iw_base + ww_flip * wt_dilation[1]; if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { r += static_cast(in_ptr_pt[c * in_stride_C]) * static_cast( wt_ptr_pt[(c % C_per_group) * wt_stride_C]); } // c } // ih, iw check } // ww } // wh out_ptr[0] = static_cast(r); out_ptr += out_stride_O; wt_ptr += wt_stride_O; } // o } // g }; int oH_border_0 = 0; int oH_border_1 = is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; int oH_border_2 = std::max( oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); int oH_border_3 = oH; int oW_border_0 = 0; int oW_border_1 = is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; int oW_border_2 = std::max( oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { // Case 1: oh might put us out of bounds for (int oh = oH_border_0; oh < oH_border_1; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow } // oh // Case 2: oh in bounds for (int oh = oH_border_1; oh < oH_border_2; ++oh) { // Case a: ow might put us out of bounds for (int ow = oW_border_0; ow < oW_border_1; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow // Case b: ow in bounds for (int ow = oW_border_1; ow < oW_border_2; ++ow) { pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow // Case c: ow might put us out of bounds for (int ow = oW_border_2; ow < oW_border_3; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow } // oh // Case 3: oh might put us out of bounds for (int oh = oH_border_2; oh < oH_border_3; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); } // ow } // oh st_in_ptr += in_stride_N; st_out_ptr += out_stride_N; } // n } template void slow_conv_3D( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { const T* st_wt_ptr = wt.data(); const T* st_in_ptr = in.data(); T* st_out_ptr = out.data(); const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim const int oD = out.shape(1); // Output spatial dim const int oH = out.shape(2); // Output spatial dim const int oW = out.shape(3); // Output spatial dim const int O = wt.shape(0); // Out channels const int C = wt.shape(4); // In channels const int wD = wt.shape(1); // Weight spatial dim const int wH = wt.shape(2); // Weight spatial dim const int wW = wt.shape(3); // Weight spatial dim const size_t in_stride_N = in.strides()[0]; const size_t in_stride_D = in.strides()[1]; const size_t in_stride_H = in.strides()[2]; const size_t in_stride_W = in.strides()[3]; const size_t in_stride_C = in.strides()[4]; const size_t wt_stride_O = wt.strides()[0]; const size_t wt_stride_D = wt.strides()[1]; const size_t wt_stride_H = wt.strides()[2]; const size_t wt_stride_W = wt.strides()[3]; const size_t wt_stride_C = wt.strides()[4]; const size_t out_stride_N = out.strides()[0]; const size_t out_stride_D = out.strides()[1]; const size_t out_stride_H = out.strides()[2]; const size_t out_stride_W = out.strides()[3]; const size_t out_stride_O = out.strides()[4]; bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1; auto pt_conv_no_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int od, int oh, int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; int id_base = od * wt_strides[0] - padding[0]; int ih_base = oh * wt_strides[1] - padding[1]; int iw_base = ow * wt_strides[2] - padding[2]; for (int o = 0; o < O; ++o) { float r = 0.; for (int wd = 0; wd < wD; ++wd) { for (int wh = 0; wh < wH; ++wh) { for (int ww = 0; ww < wW; ++ww) { int wd_flip = flip ? wD - wd - 1 : wd; int wh_flip = flip ? wH - wh - 1 : wh; int ww_flip = flip ? wW - ww - 1 : ww; int id = id_base + wd_flip * wt_dilation[0]; int ih = ih_base + wh_flip * wt_dilation[1]; int iw = iw_base + ww_flip * wt_dilation[2]; const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W; const T* in_ptr_pt = in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W; for (int c = 0; c < C; ++c) { r += static_cast(in_ptr_pt[0]) * static_cast(wt_ptr_pt[0]); in_ptr_pt += in_stride_C; wt_ptr_pt += wt_stride_C; } // c } // ww } // wh } // wd out_ptr[0] = static_cast(r); out_ptr += out_stride_O; wt_ptr += wt_stride_O; } // o }; int jump_d = flip ? -wt_dilation[0] : wt_dilation[0]; int jump_h = flip ? -wt_dilation[1] : wt_dilation[1]; int jump_w = flip ? -wt_dilation[2] : wt_dilation[2]; int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0); int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0); int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0); int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2]; int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2]; std::vector base_d(f_out_jump_d); std::vector base_h(f_out_jump_h); std::vector base_w(f_out_jump_w); for (int i = 0; i < f_out_jump_d; ++i) { int id_loop = i * wt_strides[0] - padding[0] + init_d; int wd_base = 0; while (wd_base < wD && id_loop % in_dilation[0] != 0) { wd_base++; id_loop += jump_d; } base_d[i] = wd_base; } for (int i = 0; i < f_out_jump_h; ++i) { int ih_loop = i * wt_strides[1] - padding[1] + init_h; int wh_base = 0; while (wh_base < wH && ih_loop % in_dilation[1] != 0) { wh_base++; ih_loop += jump_h; } base_h[i] = wh_base; } for (int j = 0; j < f_out_jump_w; ++j) { int iw_loop = j * wt_strides[2] - padding[2] + init_w; int ww_base = 0; while (ww_base < wW && iw_loop % in_dilation[2] != 0) { ww_base++; iw_loop += jump_w; } base_w[j] = ww_base; } auto pt_conv_all_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int od, int oh, int ow) { out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; int id_base = od * wt_strides[0] - padding[0]; int ih_base = oh * wt_strides[1] - padding[1]; int iw_base = ow * wt_strides[2] - padding[2]; int wd_base = base_d[od % f_out_jump_d]; int wh_base = base_h[oh % f_out_jump_h]; int ww_base = base_w[ow % f_out_jump_w]; for (int o = 0; o < O; ++o) { float r = 0.; for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) { for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { int wd_flip = flip ? wD - wd - 1 : wd; int wh_flip = flip ? wH - wh - 1 : wh; int ww_flip = flip ? wW - ww - 1 : ww; int id = id_base + wd_flip * wt_dilation[0]; int ih = ih_base + wh_flip * wt_dilation[1]; int iw = iw_base + ww_flip * wt_dilation[2]; if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 && iw < iW) { const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W; int id_dil = !is_idil_one ? (id / in_dilation[0]) : id; int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih; int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw; const T* in_ptr_pt = in_ptr + id_dil * in_stride_D + ih_dil * in_stride_H + iw_dil * in_stride_W; for (int c = 0; c < C; ++c) { r += static_cast(in_ptr_pt[0]) * static_cast(wt_ptr_pt[0]); in_ptr_pt += in_stride_C; wt_ptr_pt += wt_stride_C; } // c } // iD, ih, iw check } // ww } // wh } // wd out_ptr[0] = static_cast(r); out_ptr += out_stride_O; wt_ptr += wt_stride_O; } // o }; int oD_border_0 = 0; int oD_border_1 = is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; int oD_border_2 = std::max( oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); int oD_border_3 = oD; int oH_border_0 = 0; int oH_border_1 = is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; int oH_border_2 = std::max( oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); int oH_border_3 = oH; int oW_border_0 = 0; int oW_border_1 = is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; int oW_border_2 = std::max( oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { // Case 1: od might put us out of bounds for (int od = oD_border_0; od < oD_border_1; ++od) { for (int oh = 0; oh < oH; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh } // od // Case 2: od in bounds for (int od = oD_border_1; od < oD_border_2; ++od) { // Case 2.1: oh might put us out of bounds for (int oh = oH_border_0; oh < oH_border_1; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh // Case 2.2: oh in bounds for (int oh = oH_border_1; oh < oH_border_2; ++oh) { // Case 2.2.1: ow might put us out of bounds for (int ow = oW_border_0; ow < oW_border_1; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow // Case 2.2.2: ow in bounds for (int ow = oW_border_1; ow < oW_border_2; ++ow) { pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow // Case 2.2.3: ow might put us out of bounds for (int ow = oW_border_2; ow < oW_border_3; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh // Case 2.3: oh might put us out of bounds for (int oh = oH_border_2; oh < oH_border_3; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh } // od // Case 3: od might put us out of bounds for (int od = oD_border_2; od < oD_border_3; ++od) { for (int oh = 0; oh < oH; ++oh) { for (int ow = 0; ow < oW; ++ow) { pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow); } // ow } // oh } // od st_in_ptr += in_stride_N; st_out_ptr += out_stride_N; } // n } void dispatch_slow_conv_1D( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { if (in.dtype() == float32) { return slow_conv_1D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == float16) { return slow_conv_1D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == bfloat16) { return slow_conv_1D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); } } void dispatch_slow_conv_2D( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { if (in.dtype() == float32) { return slow_conv_2D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == float16) { return slow_conv_2D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == bfloat16) { return slow_conv_2D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); } } void dispatch_slow_conv_3D( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { if (in.dtype() == float32) { return slow_conv_3D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == float16) { return slow_conv_3D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == bfloat16) { return slow_conv_3D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); } } /////////////////////////////////////////////////////////////////////////////// // Explicit gemm conv /////////////////////////////////////////////////////////////////////////////// template void flip_spatial_dims_inplace(array& wt) { T* x = wt.data(); size_t out_channels = wt.shape(0); size_t in_channels = wt.shape(-1); // Calculate the total size of the spatial dimensions int spatial_size = 1; for (int d = 1; d < wt.ndim() - 1; ++d) { spatial_size *= wt.shape(d); } for (size_t i = 0; i < out_channels; i++) { T* top = x + i * spatial_size * in_channels; T* bottom = x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels; for (size_t j = 0; j < spatial_size / 2; j++) { for (size_t k = 0; k < in_channels; k++) { std::swap(top[k], bottom[k]); } top += in_channels; bottom -= in_channels; } } } void explicit_gemm_conv_1D_cpu( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iH = in.shape(1); // Input spatial dim const int C = in.shape(2); // Input channels const int oH = out.shape(1); // Output spatial dim const int O = wt.shape(0); // Out channels const int wH = wt.shape(1); // Weight spatial dim const int groups = C / wt.shape(2); const int C_per_group = wt.shape(2); const int O_per_group = O / groups; auto conv_dtype = float32; // Pad input std::vector padded_shape = {N, iH + 2 * padding[0], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros copy(array(0, conv_dtype), in_padded, CopyType::Scalar); // Pick input slice from padded size_t data_offset = padding[0] * in_padded.strides()[1]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, in_padded.strides(), in_padded.flags(), in_padded_slice.size(), data_offset); // Copy input values into the slice copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); // Make strided view std::vector strided_shape = {N, oH, wH, C}; std::vector strided_strides = { in_padded.strides()[0], in_padded.strides()[1] * wt_strides[0], in_padded.strides()[1], in_padded.strides()[2]}; auto flags = in_padded.flags(); if (groups > 1) { // Transpose the last two dimensions for grouped convolutions std::swap(strided_shape[2], strided_shape[3]); std::swap(strided_strides[2], strided_strides[3]); } array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); in_strided_view.copy_shared_buffer( in_padded, strided_strides, flags, in_strided_view.size(), 0); // Materialize strided view std::vector strided_reshape = {N * oH, wH * C}; array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); copy(in_strided_view, in_strided, CopyType::General); // Check wt dtype and prepare auto gemm_wt = wt; auto gemm_out = out; if (groups > 1) { // Transpose the last two dimensions for grouped convolutions array wt_transpose( {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {}); wt_transpose.copy_shared_buffer( wt, {wt.strides(0), wt.strides(2), wt.strides(1)}, wt.flags(), wt.size(), 0); gemm_wt = array(wt_transpose.shape(), float32, nullptr, {}); copy(wt_transpose, gemm_wt, CopyType::General); } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) { auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); copy(wt, gemm_wt, ctype); } if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); } for (int g = 0; g < groups; ++g) { // Perform gemm cblas_sgemm( CblasRowMajor, CblasNoTrans, // no trans A CblasTrans, // transB strided_reshape[0], // M O_per_group, // N C_per_group * wH, // K 1.0f, // alpha in_strided.data() + g * C_per_group * wH, // A wH * C, // lda gemm_wt.data() + g * O_per_group * C_per_group * wH, // B wH * C_per_group, // ldb 0.0f, // beta gemm_out.data() + g * O_per_group, // C O // ldc ); // Copy results if needed if (out.dtype() != float32) { copy(gemm_out, out, CopyType::Vector); } } } void explicit_gemm_conv_2D_cpu( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iH = in.shape(1); // Input spatial dim const int iW = in.shape(2); // Input spatial dim const int oH = out.shape(1); // Output spatial dim const int oW = out.shape(2); // Output spatial dim const int O = wt.shape(0); // Out channels const int C = wt.shape(3); // In channels const int wH = wt.shape(1); // Weight spatial dim const int wW = wt.shape(2); // Weight spatial dim auto conv_dtype = out.dtype(); // Pad input std::vector padded_shape = { N, iH + 2 * padding[0], iW + 2 * padding[1], C}; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros copy(array(0, conv_dtype), in_padded, CopyType::Scalar); // Pick input slice from padded size_t data_offset = padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2]; array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, in_padded.strides(), in_padded.flags(), in_padded_slice.size(), data_offset); // Copy input values into the slice copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); // Make strided view std::vector strided_shape = {N, oH, oW, wH, wW, C}; std::vector strided_strides = { in_padded.strides()[0], in_padded.strides()[1] * wt_strides[0], in_padded.strides()[2] * wt_strides[1], in_padded.strides()[1], in_padded.strides()[2], in_padded.strides()[3]}; auto flags = in_padded.flags(); array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); in_strided_view.copy_shared_buffer( in_padded, strided_strides, flags, in_strided_view.size(), 0); // Materialize strided view std::vector strided_reshape = {N * oH * oW, wH * wW * C}; array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); copy(in_strided_view, in_strided, CopyType::General); // Check wt dtype and prepare auto gemm_wt = wt; auto gemm_out = out; if (wt.dtype() != float32 || !wt.flags().row_contiguous) { auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); copy(wt, gemm_wt, ctype); } if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); } // Perform gemm cblas_sgemm( CblasRowMajor, CblasNoTrans, // no trans A CblasTrans, // transB strided_reshape[0], // M O, // N strided_reshape[1], // K 1.0f, // alpha in_strided.data(), strided_reshape[1], // lda gemm_wt.data(), strided_reshape[1], // ldb 0.0f, // beta gemm_out.data(), O // ldc ); // Copy results if needed if (out.dtype() != float32) { copy(gemm_out, out, CopyType::Vector); } } void explicit_gemm_conv_ND_cpu( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const bool flip) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const auto iDim = std::vector( in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim const auto oDim = std::vector( out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim const int O = wt.shape(0); // Out channels const int C = wt.shape(-1); // In channels const auto wDim = std::vector( wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim auto conv_dtype = float32; // Pad input std::vector padded_shape(in.shape().size()); padded_shape.front() = N; for (size_t i = 0; i < iDim.size(); i++) { padded_shape[i + 1] = iDim[i] + 2 * padding[i]; } padded_shape.back() = C; array in_padded(padded_shape, conv_dtype, nullptr, {}); // Fill with zeros copy(array(0, conv_dtype), in_padded, CopyType::Scalar); // Pick input slice from padded size_t data_offset = 0; for (size_t i = 0; i < padding.size(); i++) { data_offset += padding[i] * in_padded.strides()[i + 1]; } array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); in_padded_slice.copy_shared_buffer( in_padded, in_padded.strides(), in_padded.flags(), in_padded_slice.size(), data_offset); // Copy input values into the slice copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); // Make strided view std::vector strided_shape(oDim.size() + wDim.size() + 2); strided_shape.front() = N; for (size_t i = 0; i < oDim.size(); i++) { strided_shape[i + 1] = oDim[i]; } for (size_t i = 0; i < wDim.size(); i++) { strided_shape[i + 1 + oDim.size()] = wDim[i]; } strided_shape.back() = C; std::vector strided_strides(in.shape().size() * 2 - 2); strided_strides[0] = in_padded.strides()[0]; for (size_t i = 0; i < wt_strides.size(); i++) { strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i]; } for (size_t i = 1; i < in_padded.strides().size(); i++) { strided_strides[i + wt_strides.size()] = in_padded.strides()[i]; } auto flags = in_padded.flags(); array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); in_strided_view.copy_shared_buffer( in_padded, strided_strides, flags, in_strided_view.size(), 0); // Materialize strided view std::vector strided_reshape = {N, C}; for (const auto& o : oDim) { strided_reshape[0] *= o; } for (const auto& w : wDim) { strided_reshape[1] *= w; } array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); copy(in_strided_view, in_strided, CopyType::General); // Check wt dtype and prepare auto gemm_wt = wt; auto gemm_out = out; if (wt.dtype() != float32 || !wt.flags().row_contiguous) { auto ctype = wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; gemm_wt = array(wt.shape(), float32, nullptr, {}); copy(wt, gemm_wt, ctype); } if (flip) { auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {}); copy(gemm_wt, gemm_wt_, CopyType::Vector); flip_spatial_dims_inplace(gemm_wt_); gemm_wt = gemm_wt_; } if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); } // Perform gemm cblas_sgemm( CblasRowMajor, CblasNoTrans, // no trans A CblasTrans, // transB strided_reshape[0], // M O, // N strided_reshape[1], // K 1.0f, // alpha in_strided.data(), strided_reshape[1], // lda gemm_wt.data(), strided_reshape[1], // ldb 0.0f, // beta gemm_out.data(), O // ldc ); // Copy results if needed if (out.dtype() != float32) { copy(gemm_out, out, CopyType::Vector); } } /////////////////////////////////////////////////////////////////////////////// // Conv routing /////////////////////////////////////////////////////////////////////////////// void conv_1D_cpu( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( in, wt, out, padding, wt_strides, wt_dilation); } if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( in, wt, out, padding, wt_strides, wt_dilation, flip); } return dispatch_slow_conv_1D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } void conv_2D_cpu( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( in, wt, out, padding, wt_strides, wt_dilation, flip); } return dispatch_slow_conv_2D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } void conv_3D_cpu( const array& in, const array& wt, array out, const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && groups == 1) { return explicit_gemm_conv_ND_cpu( in, wt, out, padding, wt_strides, wt_dilation, flip); } return dispatch_slow_conv_3D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } } // namespace void Convolution::eval(const std::vector& inputs, array& out) { out.set_data(allocator::malloc_or_wait(out.nbytes())); auto& in = inputs[0]; auto& wt = inputs[1]; // 3D convolution if (in.ndim() == (3 + 2)) { return conv_3D_cpu( in, wt, out, padding_, kernel_strides_, kernel_dilation_, input_dilation_, flip_); } // 2D convolution else if (in.ndim() == (2 + 2)) { return conv_2D_cpu( in, wt, out, padding_, kernel_strides_, kernel_dilation_, input_dilation_, flip_); } // 1D convolution else if (in.ndim() == (1 + 2)) { return conv_1D_cpu( in, wt, out, padding_, kernel_strides_, kernel_dilation_, input_dilation_, flip_); } // Throw error else { std::ostringstream msg; msg << "[Convolution::eval] Convolution currently only supports" << " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2 << " spatial dimensions"; throw std::invalid_argument(msg.str()); } } } // namespace mlx::core